This document introduces the deep reinforcement learning model 'A3C' (originally presented in Japanese).
The original paper is "Asynchronous Methods for Deep Reinforcement Learning" by V. Mnih et al.
2. The paper covered this time
[1] Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P. Lillicrap, Tim Harley, David Silver, and Koray Kavukcuoglu. Asynchronous methods for deep reinforcement learning. In Proceedings of the 33rd International Conference on Machine Learning (ICML), pp. 1928–1937, 2016.
By using asynchronous methods it does away with the replay memory, and it achieves faster and more accurate learning than DQN!
4. Reinforcement Learning Basics ①

Loss function for 1-step Q-learning:
$L_i(\theta_i) = \mathbb{E}\left[ \left( r + \gamma \max_{a'} Q(s',a';\theta_{i-1}) - Q(s,a;\theta_i) \right)^2 \right]$

Loss function for 1-step Sarsa:
$L_i(\theta_i) = \mathbb{E}\left[ \left( r + \gamma Q(s',a';\theta_{i-1}) - Q(s,a;\theta_i) \right)^2 \right]$

Loss function for n-step Q-learning:
$L_i(\theta_i) = \mathbb{E}\left[ \left( \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n \max_{a'} Q(s',a';\theta_{i-1}) - Q(s,a;\theta_i) \right)^2 \right]$

Gradient of the objective in actor-critic:
$\nabla_\theta J(\theta) = \mathbb{E}\left[ \nabla_\theta \log \pi(a_t \mid s_t;\theta) \left( R_t - V^\pi(s_t) \right) \right]$

$\gamma$: discount rate
$r$: reward
$Q(s,a;\theta_i)$: action-value function for taking action $a$ in state $s$
$V^\pi(s_t)$: value function of state $s_t$
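To make these targets concrete, here is a minimal numpy sketch (my own illustration, not code from the paper); `q_prev_next` stands in for the action values $Q(s',\cdot;\theta_{i-1})$ and `rewards` holds $r_t, \dots, r_{t+n-1}$:

```python
import numpy as np

gamma = 0.99  # discount rate γ

def one_step_q_target(r, q_prev_next):
    # r + γ max_a' Q(s', a'; θ_{i-1})
    return r + gamma * np.max(q_prev_next)

def one_step_sarsa_target(r, q_prev_next, a_next):
    # r + γ Q(s', a'; θ_{i-1}) for the action a' that was actually taken
    return r + gamma * q_prev_next[a_next]

def n_step_q_target(rewards, q_prev_last):
    # Σ_{k=0}^{n-1} γ^k r_{t+k} + γ^n max_a' Q(s_{t+n}, a'; θ_{i-1})
    n = len(rewards)
    n_step_return = sum(gamma**k * r for k, r in enumerate(rewards))
    return n_step_return + gamma**n * np.max(q_prev_last)
```

In each loss above, the squared difference between such a target and $Q(s,a;\theta_i)$ is what gets minimized.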
5. Reinforcement Learning Basics ②

Loss function for 1-step Q-learning:
$L_i(\theta_i) = \mathbb{E}\left[ \left( r + \gamma \max_{a'} Q(s',a';\theta_{i-1}) - Q(s,a;\theta_i) \right)^2 \right]$

In the case of DQN, this becomes the DQN loss function:
$L(\theta) = \mathbb{E}_{s,a,r,s' \sim D}\left[ \left( r + \gamma \max_{a'} Q(s',a';\theta^-) - Q(s,a;\theta) \right)^2 \right]$

$D$: experience replay memory
$\theta^-$: target network
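As one way to see how $D$ and $\theta^-$ enter this loss, here is a minimal PyTorch sketch (an illustration under assumed interfaces, not the DeepMind implementation): `q_net` plays $Q(\cdot;\theta)$, `target_net` plays the frozen $Q(\cdot;\theta^-)$, and `replay` is a list of transition tuples standing in for $D$.

```python
import random
import torch
import torch.nn.functional as F

def dqn_loss(q_net, target_net, replay, batch_size=32, gamma=0.99):
    # Sample a minibatch of (s, a, r, s') tuples uniformly from D.
    batch = random.sample(replay, batch_size)
    s  = torch.stack([b[0] for b in batch])
    a  = torch.tensor([b[1] for b in batch])   # int64 actions
    r  = torch.tensor([b[2] for b in batch])   # float rewards
    s2 = torch.stack([b[3] for b in batch])

    with torch.no_grad():  # θ^- is held fixed; no gradient flows through it
        target = r + gamma * target_net(s2).max(dim=1).values

    q_sa = q_net(s).gather(1, a.unsqueeze(1)).squeeze(1)  # Q(s, a; θ)
    return F.mse_loss(q_sa, target)
```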
12. How Gorila Works
A. Nair, et al. "Massively parallel methods for deep reinforcement learning." In ICML Deep Learning Workshop, 2015.
13. How Gorila Works, ver. 1
A single shared replay memory is used.
[Diagram: N actor-learner sets. Each set pairs one Actor machine (Environment + Q Network) with one Learner machine (Q Network, Target Q Network, DQN Loss). All N sets share one Replay Memory and synchronize with a Parameter Server split into Shards 1…K.]
14. How Gorila Works, ver. 2 (bundled mode)
Individual replay memories are used.
[Diagram: the same N actor-learner sets and sharded Parameter Server, but each set now keeps its own Replay Memory on its own machine.]
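The only difference between ver. 1 and ver. 2 is where the replay memory lives. A toy sketch of the two layouts (the names and structures here are my own, not from the paper):

```python
from collections import deque

N = 4  # number of actor-learner sets

# ver. 1: one replay memory shared by all N sets
shared_replay = deque(maxlen=1_000_000)
sets_v1 = [{"replay": shared_replay} for _ in range(N)]

# ver. 2 (bundled mode): each set keeps its own local replay memory
sets_v2 = [{"replay": deque(maxlen=1_000_000)} for _ in range(N)]

assert sets_v1[0]["replay"] is sets_v1[1]["replay"]      # shared
assert sets_v2[0]["replay"] is not sets_v2[1]["replay"]  # per-set
```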
15. From Gorila (bundled mode) to asynchronous DQN: change ①
Each actor-learner pair now corresponds to a single thread on one CPU, and the replay memory is removed.
[Diagram: the bundled-mode layout with the Replay Memory deleted and each Actor/Learner set mapped to one CPU thread.]
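A minimal sketch of this change, with a hypothetical worker body `act_and_learn` (the real worker would act, compute the loss on fresh transitions, and accumulate gradients, as the following slides describe):

```python
import threading

NUM_THREADS = 4  # each thread replaces one Actor machine + Learner machine pair

def act_and_learn(thread_id):
    # Placeholder: interact with the environment and learn from each
    # transition immediately, with no replay memory in between.
    pass

threads = [threading.Thread(target=act_and_learn, args=(i,))
           for i in range(NUM_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```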
16. From Gorila (bundled mode) to asynchronous DQN: change ②
Instead, each thread accumulates gradients.
[Diagram: the Replay Memory boxes are gone; each Learner now holds accumulated gradients to send to the Parameter Server.]
17. From Gorila (bundled mode) to asynchronous DQN: change ③
A separate server is created for the Target Q-Network.
[Diagram: two sharded servers, a Parameter Server for the Q-Network and a Parameter Server for the Target Q-Network, each with Shards 1…K.]
18. The flow of asynchronous DQN ①
Copy $\theta$ from the Parameter Server for the Q-Network and $\theta^-$ from the Parameter Server for the Target Q-Network.
[Diagram: the two-server layout from slide 17, with θ and θ⁻ copied down to each Actor/Learner thread.]
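A toy sketch of this step, with plain dicts standing in for the two sharded parameter servers (all names here are my own):

```python
import copy
import numpy as np

# Stand-ins for the two parameter servers (toy weight shapes).
server_q      = {"theta":       [np.zeros((4, 2))]}
server_target = {"theta_minus": [np.zeros((4, 2))]}

# Each thread starts by pulling local copies of θ and θ^-.
theta       = copy.deepcopy(server_q["theta"])
theta_minus = copy.deepcopy(server_target["theta_minus"])
```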
19. The flow of asynchronous DQN ②
Take action $a$ in state $s$ and observe $s'$ and $r$.
[Diagram: same layout; each Actor interacts with its Environment.]
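A toy sketch of this step. The ε-greedy choice is the usual behavior policy in DQN-style agents, and `env.step` is assumed to follow the common Gym-style interface:

```python
import random
import numpy as np

NUM_ACTIONS = 2

def q_values(s):
    # Stand-in for Q(s, ·; θ) computed with the locally copied θ.
    return np.zeros(NUM_ACTIONS)

def select_action(s, epsilon=0.1):
    if random.random() < epsilon:
        return random.randrange(NUM_ACTIONS)   # explore
    return int(np.argmax(q_values(s)))         # exploit

# In a real loop:  a = select_action(s)
#                  s_next, r, done, info = env.step(a)   # observe s' and r
```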
20. The flow of asynchronous DQN ③
Compute the loss:
$L(\theta) = \mathbb{E}\left[ \left( r + \gamma \max_{a'} Q(s',a';\theta^-) - Q(s,a;\theta) \right)^2 \right]$
[Diagram: same layout; each Learner computes the DQN loss.]
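A toy numpy sketch of this step for a single transition, using the local copies of $\theta$ and $\theta^-$ pulled in flow ①:

```python
import numpy as np

gamma = 0.99

def td_loss(r, q_next_target, q_sa):
    # (r + γ max_a' Q(s', a'; θ^-) − Q(s, a; θ))^2 for one transition
    target = r + gamma * np.max(q_next_target)
    return (target - q_sa) ** 2

# Example: reward 1.0, target-network values for s', and Q(s,a;θ) = 0.5
print(td_loss(1.0, np.array([0.2, 0.7]), 0.5))
```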
21. The flow of asynchronous DQN ④
Accumulate the gradients: $d\theta \leftarrow d\theta + \dfrac{\partial L(\theta)}{\partial \theta}$
[Diagram: same layout; gradients pile up locally in each Learner.]
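A toy sketch of the accumulation, with numpy arrays standing in for the parameter and gradient tensors:

```python
import numpy as np

d_theta = [np.zeros((4, 2))]          # running buffer for dθ

def accumulate(d_theta, grads):
    # dθ ← dθ + ∂L(θ)/∂θ, applied per parameter tensor
    for acc, g in zip(d_theta, grads):
        acc += g

accumulate(d_theta, [np.full((4, 2), 0.01)])  # one step's gradient
```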
22. The flow of asynchronous DQN ⑤
Periodically send the accumulated gradients $d\theta$ to the server to update the parameters.
[Diagram: same layout; the accumulated dθ is pushed up to the Parameter Server for the Q-Network.]
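A toy sketch of this last step, using plain SGD as a stand-in for the actual optimizer (the paper experiments with RMSProp variants):

```python
import numpy as np

def push_gradients(server_theta, d_theta, lr=1e-4):
    # Apply the accumulated gradient on the server side, then clear the
    # local buffer so the next accumulation interval starts from zero.
    for w, g in zip(server_theta, d_theta):
        w -= lr * g
        g[:] = 0.0

server_theta = [np.ones((4, 2))]
d_theta      = [np.full((4, 2), 0.5)]
push_gradients(server_theta, d_theta)  # e.g. called every few steps
```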