1. Distributed Stochastic Gradient
MCMC
Sungjin Ahn1
Babak Shahbaba2
Max Welling3
1
Department of Computer Science, University of California, Irvine
2
Department of Statistics, University of California, Irvine
3
Machine Learning Group, University of Amsterdam
June 20, 2016
1 / 28
13. Distributed Inference in LDA
Approximate Distributed LDA (AD-LDA) [Newman et al, 2007]
• MCMC にかかる計算時間を減らすため,それぞれの local
shard に周辺化ギブスサンプリングを行う手法.
• N
1回のサンプリングごとの計算コストが ( S ) まで減少.
• global states との同期により local copy の重みを修正できる.
13 / 28
14. Distributed Inference in LDA
AD-LDA の欠点:
• データセットのサイズが大きいと,worker を追加しても遅い.
• global states との同期のせいで block-by-the-slowest に苦し
む. block-by-the-slowest:最も遅い worker のタスク完了を他の
worker が待機している状態.
• 並列化した連鎖の実行に大きな overhead(並列化計算のための
処理) がかかる.
14 / 28
15. Distributed Inference in LDA
Yahoo-LDA (Y-LDA) [Ahmed et al, 2012]
• 非同期での更新により,block-by-the-slowest の解決をした.
• 非同期で無限に更新するとパフォーマンスが悪化する.
[Ho et al, 2013]
15 / 28
17. SGLD on Partitioned Datasets
準備
• sudataset X = {x1,..., xN } を S 個の bset(shard) に分割:
X1,..., XS, X = ∪sXs, N =
∑
s Ns
• データ x が与えられた時の対数尤度 (score function):
g(θ; x) = ∇θ log p(θ; x)
• X からサンプリングされた n 個のデータ点のミニバッチ:X n
shard Xs からサンプリングされたとき:Xs
n
イ テレーション t で X n
s
がサンプリングされたとき:X n
s,t
• score function の合計:G(θ ; X ) =
∑
x∈X g(θ; x)
score function の平均:g¯(θ ; X ) = |X
1
|
G(θ; X)
17 / 28
18. SGLD on Partitioned Datasets
Proposition
shard s = 1,...,S:
• shard size: Ns(Ns > 0,
∑
s Ns = N)
• 正規化された shard の選択頻度: qs(qs ∈ (0, 1),
∑
s qs = 1)
このとき以下の推定値は SGLD の推定値として妥当である.
¯gd(θ; X n
s
)
de f
=
Ns
Nqs
¯g(θ; X n
s
)
ここで,shard s は,scheduler h( ) からサンプリングされる.ただ
し,頻度 = {q1,...,qS}.
証明は省略(supplementary material があるらしい)
18 / 28
19. SGLD on Partitioned Datasets
流れ
(1) shard をサンプリングで選ぶ.
s ∼ h( ) = Category(q1,...,qS)
(2) 選んだ shard からミニバッチ X n
s
をサンプリングする.
(3) ミニバッチを使って score 平均 g¯(θ ; Xs
n ) 計算 .
(4) score 平均に N
N
q
s
s
をかけて,重みを修正する.
19 / 28
20. SGLD on Partitioned Datasets
SGLD update rule
θt+1 ← θt +
εt
2
∇log p(θt) +
Nst
qst
¯g(θt; X n
st
) + νt
• ¯g(θt; X n
st
) の項は step size の補正になっている.
• このアルゴリズムは相対的にサイズが大きい,または他より使用
されていない shard に対して,大きな step をとる.(全ての
data-case が連鎖の混合に等しく用いられている)
20 / 28