TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

Prioritized Experience Replayのsum-treeの実装

つくりながら学ぶ!深層強化学習のPrioritized Experience Replayの実装は、説明をシンプルにするためReplay Memoryを線形で探索する実装が紹介されていた。
つまり、各transitionのTD誤差を優先度として、0からReplay Memoryの優先度の合計の間で、ランダムに数値を選び、Replay Memoryの格納順に優先度を足していき合計がランダムに選んだ数値超えた要素をサンプリングするという方法だ。

一方、元の論文では、Replay Memoryのサイズが大きい場合、線形探索は速度が問題になるため、sum-treeで実装することで計算量O(log N)にできることが説明されている。
論文には、sum-treeの具体的な説明がなかったので調べてみた。

sum-tree

sum-treeについては、以下のサイトの説明がわかりやすい。
https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/
f:id:TadaoYamaoka:20190818154503p:plain

2分木の一番下段が、Replay Memoryに格納されたtransitionの優先度にあたり、親ノードが子ノードの合計値になっている。
このように2分木を構築することで、ランダムに選んだ数値から2分木を辿ることで、O(log N)でサンプリングができる。
sum-treeのPythonでの実装もGitHubで公開されている。

上記サイトで公開されていたPythonのコードを使って、実際に優先度に応じた確率でサンプリングされるか確認してみた。

sum-treeによるサンプリングの実験

まずは、サイズを指定してReplay Memoryを作成する。ここでは、サイズ10とした。

from SumTree import SumTree

replay_memory = SumTree(10)


作成後は、要素が追加されていないので、合計は0になる。

replay_memory.total()
Out: 0.0


適当な優先度を付けた10個の要素を追加する。関連付けるデータも格納できるが、ここではNoneを格納する。

P = [6, 48, 31, 26, 49, 43, 93, 74, 79, 13]

for p in P:
    replay_memory.add(p, None)


合計を確認する。

print(sum(P), replay_memory.total())
462 462.0

追加した値の合計と、replay_memoryのtotal()の値が一致している。


1万回サンプリングを行い、各要素が選択された割合を確認する。

from collections import defaultdict
p_sum = defaultdict(int)

for i in range(10000):
    s = random.randint(0, replay_memory.total() - 1)
    _, p, _ = replay_memory.get(s)
    p_sum[int(p)] += 1

for p in P:
    print("{}, {}, {}".format(p, p / sum(P), p_sum[p] / 10000))
6, 0.012987012987012988, 0.0128
48, 0.1038961038961039, 0.0987
31, 0.0670995670995671, 0.0648
26, 0.05627705627705628, 0.0591
49, 0.10606060606060606, 0.1061
43, 0.09307359307359307, 0.0878
93, 0.2012987012987013, 0.2075
74, 0.16017316017316016, 0.1622
79, 0.170995670995671, 0.174
13, 0.02813852813852814, 0.027

優先度に応じた確率(2列目)と、サンプリングされた割合(3列目)がほぼ等しくなっていることが確認できる。