def solve(N, lessthan):
    """Count the permutations of length 2 that satisfy the last comparison.

    First step of the derivation: only the N = 2 case is handled.  ``N`` is
    accepted for signature compatibility but is not read.

    ``table[i]`` is the number of ways to arrange the last k + 1 positions
    when the element at the front of that suffix is the i-th smallest
    (0-based) of them, i.e. exactly i of the later elements are smaller.

    Args:
        N: Permutation length (only 2 is handled here; not read).
        lessthan: Comparison flags; a truthy ``lessthan[-1]`` means the last
            pair must be increasing ("<"), a falsy one decreasing (">").

    Returns:
        The number of valid permutations (1 for N = 2 with either flag).
    """
    # init
    # f("<", lower=0, upper=1) = 1, f("<", 1, 0) = 0
    k = 1
    table = [0] * (k + 1)
    if lessthan[-1]:
        # "<": the front element must be the smaller one, so only i = 0 works.
        for i in range(k + 1):
            table[i] = k - i
    else:
        # ">": the front element must be the larger one, so only i = k works.
        for i in range(k + 1):
            table[i] = i
    return sum(table)
次に、N=3 のときにも正しく動くコードを作る。
python
def solve(N, lessthan):
    """Count valid permutations for N = 2 or N = 3 (second derivation step).

    ``table[i]`` is the number of valid arrangements of the last k + 1
    positions in which the front element of that suffix is the i-th smallest
    (0-based) of them.  Starting from the last two positions (k = 1), this
    version applies exactly one extension step (k = 2).  The result is
    therefore only correct for N = 2 and N = 3.  For larger N it silently
    returns the count for the last three positions.

    Args:
        N: Permutation length.
        lessthan: Comparison flags; a truthy entry means "<" (the earlier
            position is smaller), a falsy one means ">".

    Returns:
        The number of valid permutations (not reduced modulo anything).
    """
    # Base case: the last two positions (k = 1).
    k = 1
    table = [0] * (k + 1)
    if lessthan[-1]:
        for i in range(k + 1):
            table[i] = k - i
    else:
        for i in range(k + 1):
            table[i] = i
    if N > 2:
        # Extend the suffix by one position to the front (k = 2).
        k = 2
        newtable = [0] * (k + 1)
        if lessthan[-2]:
            # "<": the next element's rank j among the remaining k elements
            # must be >= i, so sum table[i..k-1].
            for i in range(k + 1):
                for j in range(k - i):
                    newtable[i] += table[j + i]
        else:
            # ">": the next element's rank j must be < i, so sum table[0..i-1].
            for i in range(k + 1):
                for j in range(i):
                    newtable[i] += table[j]
        table = newtable
    return sum(table)
そして一般の N に拡張する。
これでサンプルケースはすべて通る。
しかし提出すると TLE になる（区間の和を毎回二重ループで計算しているため、全体で O(N^3) かかる）。
python
def solve(N, lessthan, mod=None):
    """Count permutations of length N satisfying every comparison, modulo ``mod``.

    This is the general version of the derivation.  It runs in O(N^3), because
    each transition recomputes a range sum with an inner loop, so it is too
    slow for large N.  Processing runs from the end of the permutation
    towards the front.  After the last k + 1 positions are handled,
    ``table[i]`` holds the number of valid arrangements of those positions
    in which the element at position N - 1 - k is the i-th smallest
    (0-based) of them.

    Args:
        N: Permutation length.
        lessthan: Comparison flags, presumably of length N - 1 (not checked).
            A truthy ``lessthan[p]`` means position p must be smaller than
            position p + 1 ("<"); a falsy one means larger (">").
        mod: Modulus for the result.  ``None`` (the default) falls back to
            the module-level ``MOD`` the original code used.

    Returns:
        The number of valid permutations, reduced modulo ``mod``.
    """
    if mod is None:
        # NOTE(review): MOD is not defined in this snippet; presumably a
        # module-level constant such as 10**9 + 7 -- confirm.
        mod = MOD
    if N <= 1:
        # No comparisons exist (lessthan is empty), so the lessthan[-1]
        # lookup below must be skipped; exactly one permutation is valid.
        return 1 % mod
    # Base case: the last two positions (k = 1).
    k = 1
    if lessthan[-1]:
        table = [k - i for i in range(k + 1)]  # "<": front must be smaller
    else:
        table = list(range(k + 1))  # ">": front must be larger
    for k in range(2, N):
        # Extend the suffix to start at position N - 1 - k.
        newtable = [0] * (k + 1)
        if lessthan[-k]:
            # "<": the next element's rank j among the remaining k elements
            # must be >= i, so sum table[i..k-1].
            for i in range(k + 1):
                for j in range(k - i):
                    newtable[i] += table[j + i]
        else:
            # ">": the next element's rank j must be < i, so sum table[0..i-1].
            for i in range(k + 1):
                for j in range(i):
                    newtable[i] += table[j]
        table = [x % mod for x in newtable]
    return sum(table) % mod
ループの中身は要するに区間和なので、累積和を使えば各区間和を O(1) で求められ、全体で O(N^2) に高速化できる。これで AC。
python
for k in range(2, N):
    # prefix[m] == table[0] + ... + table[m - 1], so each range sum that the
    # earlier double loop accumulated element by element is one subtraction.
    prefix = [0, *accumulate(table)]
    if lessthan[-k]:
        # Replaces: for j in range(k - i): newtable[i] += table[j + i]
        newtable = [prefix[k] - prefix[i] for i in range(k + 1)]
    else:
        # Replaces: for j in range(i): newtable[i] += table[j]
        newtable = [prefix[i] for i in range(k + 1)]
    table = [x % MOD for x in newtable]