つくりながら学ぶ!深層強化学習のPrioritized Experience Replayの実装は、説明をシンプルにするためReplay Memoryを線形で探索する実装が紹介されていた。
つまり、各transitionのTD誤差を優先度として、0からReplay Memoryの優先度の合計の間で、ランダムに数値を選び、Replay Memoryの格納順に優先度を足していき合計がランダムに選んだ数値超えた要素をサンプリングするという方法だ。
一方、元の論文では、Replay Memoryのサイズが大きい場合、線形探索は速度が問題になるため、sum-treeで実装することで計算量O(log N)にできることが説明されている。
論文には、sum-treeの具体的な説明がなかったので調べてみた。
sum-tree
sum-treeについては、以下のサイトの説明がわかりやすい。
https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/
2分木の一番下段が、Replay Memoryに格納されたtransitionの優先度にあたり、親ノードが子ノードの合計値になっている。
このように2分木を構築することで、ランダムに選んだ数値から2分木を辿ることで、O(log N)でサンプリングができる。
sum-treeのPythonでの実装もGitHubで公開されている。
上記サイトで公開されていたPythonのコードを使って、実際に優先度に応じた確率でサンプリングされるか確認してみた。
sum-treeによるサンプリングの実験
まずは、サイズを指定してReplay Memoryを作成する。ここでは、サイズ10とした。
from SumTree import SumTree replay_memory = SumTree(10)
作成後は、要素が追加されていないので、合計は0になる。
replay_memory.total()
Out: 0.0
適当な優先度を付けた10個の要素を追加する。関連付けるデータも格納できるが、ここではNoneを格納する。
P = [6, 48, 31, 26, 49, 43, 93, 74, 79, 13] for p in P: replay_memory.add(p, None)
合計を確認する。
print(sum(P), replay_memory.total())
462 462.0
追加した値の合計と、replay_memoryのtotal()の値が一致している。
1万回サンプリングを行い、各要素が選択された割合を確認する。
from collections import defaultdict p_sum = defaultdict(int) for i in range(10000): s = random.randint(0, replay_memory.total() - 1) _, p, _ = replay_memory.get(s) p_sum[int(p)] += 1 for p in P: print("{}, {}, {}".format(p, p / sum(P), p_sum[p] / 10000))
6, 0.012987012987012988, 0.0128 48, 0.1038961038961039, 0.0987 31, 0.0670995670995671, 0.0648 26, 0.05627705627705628, 0.0591 49, 0.10606060606060606, 0.1061 43, 0.09307359307359307, 0.0878 93, 0.2012987012987013, 0.2075 74, 0.16017316017316016, 0.1622 79, 0.170995670995671, 0.174 13, 0.02813852813852814, 0.027
優先度に応じた確率(2列目)と、サンプリングされた割合(3列目)がほぼ等しくなっていることが確認できる。