SlideShare ist ein Scribd-Unternehmen logo
1 von 28
Downloaden Sie, um offline zu lesen
Distributed Stochastic Gradient
MCMC
Sungjin Ahn1
Babak Shahbaba2
Max Welling3
1
Department of Computer Science, University of California, Irvine
2
Department of Statistics, University of California, Irvine
3
Machine Learning Group, University of Amsterdam
June 20, 2016
1 / 28
1 おさらい
2 Stochastic Gradient Langevin Dynamics
3 Distributed Inference in LDA
4 Distributed Stochastic Gradient Langevin Dynamics
5 Experiments
2 / 28
ベイズ推定
•
• 尤度: p(x1:n|θ ) =
データ: x1:n = (x1,
∏
x2, ..., xn)
n
i=1
p(xi|θ)
• 事後分布: p(θ |x1:n) ∝ p(x1:n|θ )p(θ )
• 予測分布:
∫
p(x|θ )p(θ |x1:n)dθ
計算コストが高い
→ 予測分布をサンプリング近似法や変分ベイズで近似する.
3 / 28
サンプリング近似法
点推定:一つの推定されたパラメータを用いて予測する.
サンプリング近似:
事後分布からサンプリングされた複数のパラメータの平均によって予
測をする.
θ(s)
∼ p(θ|x1:n),
∫
p(x|θ)p(θ|x1:n)dθ ≈
1
S
S∑
s=1
p(x|θ(s)
)
n が大きいほど,モデルの学習に計算コストがかかる.
→ 近年,Subsampling を用いる手法が多く提案されている.
4 / 28
Abstract
• 確率的最適化に基づいた並列化 MCMC を提案した.
• stochastic gradient MCMC を並列化する際にの問題に対処法
を示した.
• 提案手法を LDA に用いて Wikipedia と Pubmed の大規模コー
パスを学習させたところ,通常の並列化 MCMC での学習時間を
27 時間→ 30 分に軽減できた.(perplexity の収束時間)
5 / 28
Mini-batch-based MCMC
Stochastic Gradient Langevin Dynamics (SGLD)
[Welling & Teh, 2011]
• 単純な Stochastic gradient ascent を拡張する.
• MAP ではなく完全な事後分布からサンプリングするようなベ
イズ推定のアルゴリズムを考える.
• Langevin Dynamics [Neal, 2010] は MCMC の手法の一つ.解
が単なる MAP に陥らないよう,パラメータの更新時にGaussian
noise を加える.
6 / 28
Mini-batch-based MCMC
準備
• データセット X = {x1, x2, ..., xN },パラメータ θ ∈ d
• モデルの同時分布 p(X , θ ) ∝ p(X |θ )p(θ )
• 目的:事後分布 p(θ |X ) からのサンプルを得る.
7 / 28
Stochastic Gradient Ascent
確率的最適化アルゴリズム [Robbins & Monro, 1951]
At iterattion t = 1,2,... :
• データセットから subset {xt1,..., xtn} (n << N) をとる.
• subset を用いて対数事後分布の勾配の近似値を計算する.
∇log p(θt|X) ≈ ∇log p(θt) +
n
N
n∑
i=1
∇log p(xti|θt)
• この値を用いてパラメータの値を更新する.
θt+1 = θt +
εt
2
∇log p(θt) +
n
N
n∑
i=1
∇log p(xti|θt)
8 / 28
Stochastic Gradient Ascent
収束するためのステップサイズ εt の主な条件は次である.
∞∑
t=1
εt = ∞,
∞∑
t=1
ε2
t
< ∞
• パラメータ空間を幅広く行ったり来たりさせるために,ス
テップサイズを小さくしすぎない.
• 局所最適解 (MAP) に収束させるために,ステップサイズを 0にま
で減少させない.
9 / 28
Langevin Dynamics
事後分布 p(θ |X ) に収束するダイナミクスを描く確率微分方程式:
∆θ(t) =
1
2
∇log p(θ(t)|X) + ∆b(t)
ここで,b(t) はブラウン運動である.
• 勾配項は確率が高い場所でより多くの時間を費やすようにダ
イナミクスを促進する.
• ダイナミクスがパラメータ空間全体を探索するようにブラウ
ン運動をノイズとして与える.
10 / 28
Langevin Dynamics
オイラーの有限差分法で離散化すると,
θt+1 = θt +
ε
2
∇log p(θt|X) + ηt ηt ∼ N(0,ε)
• ノイズの総計は勾配のステップサイズの平均
• 有限のステップサイズの場合は離散化誤差が発生する.
• 離散化はメトロポリス・ヘイスティング法(MH 法)のaccept/
reject step により修正される.
• ε → 0 のとき,acceptance rate は 1 になる.
11 / 28
Stochastic Gradient Langevin Dynamics
subset Xt = {xt1,..., xtn} (n << N)
p(θt|Xt) ∝ p(θt)
∏n
i=1
p(xti|θt)
サンプル点は次の式から生成される.
θt+1 = θt +
ε
2
∇log p(θt) +
N
n
n∑
i=1
∇log p(xti|θt)
Stochastic Gradient Acsent
+ ηt
noise
• Gaussian noise:ηt ∼ N(0,εt)
• Annealed step-size:
∑∞
t=1
εt = ∞,
∑∞
t=1
ε2
t
< ∞
• MH 法の accept/reject step は計算コストが高いが,
acceptance probability が1になるので無視できる.
12 / 28
Distributed Inference in LDA
Approximate Distributed LDA (AD-LDA) [Newman et al, 2007]
• MCMC にかかる計算時間を減らすため,それぞれの local
shard に周辺化ギブスサンプリングを行う手法.
• N
1回のサンプリングごとの計算コストが ( S ) まで減少.
• global states との同期により local copy の重みを修正できる.
13 / 28
Distributed Inference in LDA
AD-LDA の欠点:
• データセットのサイズが大きいと,worker を追加しても遅い.
• global states との同期のせいで block-by-the-slowest に苦し
む. block-by-the-slowest:最も遅い worker のタスク完了を他の
worker が待機している状態.
• 並列化した連鎖の実行に大きな overhead(並列化計算のための
処理) がかかる.
14 / 28
Distributed Inference in LDA
Yahoo-LDA (Y-LDA) [Ahmed et al, 2012]
• 非同期での更新により,block-by-the-slowest の解決をした.
• 非同期で無限に更新するとパフォーマンスが悪化する.
[Ho et al, 2013]
15 / 28
Distributed SGLD
Distributed Stochastic Gradient Langevin Dynamics (D-
SGLD):
• SGLD を並列計算し,大規模データに対しても高速なサンプ
リングを可能にしたい.
• local shards からランダムにミニバッチをサンプリングするSGLD
algorithm を提案する.
16 / 28
SGLD on Partitioned Datasets
準備
• sudataset X = {x1,..., xN } を S 個の bset(shard) に分割:
X1,..., XS, X = ∪sXs, N =
∑
s Ns
• データ x が与えられた時の対数尤度 (score function):
g(θ; x) = ∇θ log p(θ; x)
• X からサンプリングされた n 個のデータ点のミニバッチ:X n
shard Xs からサンプリングされたとき:Xs
n
イ テレーション t で X n
s
がサンプリングされたとき:X n
s,t
• score function の合計:G(θ ; X ) =
∑
x∈X g(θ; x)
score function の平均:g¯(θ ; X ) = |X
1
|
G(θ; X)
17 / 28
SGLD on Partitioned Datasets
Proposition
shard s = 1,...,S:
• shard size: Ns(Ns > 0,
∑
s Ns = N)
• 正規化された shard の選択頻度: qs(qs ∈ (0, 1),
∑
s qs = 1)
このとき以下の推定値は SGLD の推定値として妥当である.
¯gd(θ; X n
s
)
de f
=
Ns
Nqs
¯g(θ; X n
s
)
ここで,shard s は,scheduler h( ) からサンプリングされる.ただ
し,頻度 = {q1,...,qS}.
証明は省略(supplementary material があるらしい)
18 / 28
SGLD on Partitioned Datasets
流れ
(1) shard をサンプリングで選ぶ.
s ∼ h( ) = Category(q1,...,qS)
(2) 選んだ shard からミニバッチ X n
s
をサンプリングする.
(3) ミニバッチを使って score 平均 g¯(θ ; Xs
n ) 計算 .
(4) score 平均に N
N
q
s
s
をかけて,重みを修正する.
19 / 28
SGLD on Partitioned Datasets
SGLD update rule
θt+1 ← θt +
εt
2
∇log p(θt) +
Nst
qst
¯g(θt; X n
st
) + νt
• ¯g(θt; X n
st
) の項は step size の補正になっている.
• このアルゴリズムは相対的にサイズが大きい,または他より使用
されていない shard に対して,大きな step をとる.(全ての
data-case が連鎖の混合に等しく用いられている)
20 / 28
Traveling Worker Parallel Chains
問題点
• short-communication-cycle problem:
連鎖はイテレーションごとに新しい worker に遷移する必要があ
るため,伝達サイクルが短い.
• block-by-the-slowest problem:
worker の偏りによる反応遅れのせいで,次にスケジューリングさ
れている worker が処理待ちになる.
21 / 28
Distributed Trajectory Sampling
short-communication-cycle problem への対処
→ trajectory sampling で軽減できる.
• 他の worker に移る代わりに,ひとつの worker を訪れるごとに
連鎖 c で τ 回連続して更新する.
• τ 更新後.最後 (τ 番目) の状態から次の連鎖の worker に移る
• communication-cycle が (n) → (τn) に増加.
• communication overhead は減少.
22 / 28
Adaptive Load Balancing
block-by-the-slowest problem への対処
→ 遅い worker のタスクが終了するまで,速い worker を長めに働
かせる.
• worker 全体の反応時間が可能な限り平均化される.
23 / 28
D-SGLD Psedo Code
24 / 28
Dataset
データセットは以下の2つを使用した.
• Wikipedia corpus:
4.6M articles of approximately 811M tokens in total.
• PubMed Abstract corpus:
8.2M articles of approximately 730M tokens in total.
25 / 28
比較手法
• AD-LDA
• Async-LDA (Y-LDA)
[Ahmed et al., 2012; Smola & Narayanamurthy, 2010]
• SGRLD: Stochastic gradient Riemannian Langevin dynamics
[Patterson & Teh, 2013)]
26 / 28
Perplexity:予測性能
27 / 28
Conclution
本論文では D-SGLD を紹介し,以下のことを示した.
• 適切な修正項を加えることで,提案アルゴリズムは local
subset の偏りを防いだ.
• trajectory sampling により,communication overhead を減少
させた.
• 分散低減法により,収束スピードを早めた.
28 / 28

Weitere ähnliche Inhalte

Was ist angesagt?

比例ハザードモデルはとってもtricky!
比例ハザードモデルはとってもtricky!比例ハザードモデルはとってもtricky!
比例ハザードモデルはとってもtricky!
takehikoihayashi
 

Was ist angesagt? (20)

coordinate descent 法について
coordinate descent 法についてcoordinate descent 法について
coordinate descent 法について
 
強化学習@PyData.Tokyo
強化学習@PyData.Tokyo強化学習@PyData.Tokyo
強化学習@PyData.Tokyo
 
「内積が見えると統計学も見える」第5回 プログラマのための数学勉強会 発表資料
「内積が見えると統計学も見える」第5回 プログラマのための数学勉強会 発表資料 「内積が見えると統計学も見える」第5回 プログラマのための数学勉強会 発表資料
「内積が見えると統計学も見える」第5回 プログラマのための数学勉強会 発表資料
 
社会心理学者のための時系列分析入門_小森
社会心理学者のための時系列分析入門_小森社会心理学者のための時系列分析入門_小森
社会心理学者のための時系列分析入門_小森
 
ようやく分かった!最尤推定とベイズ推定
ようやく分かった!最尤推定とベイズ推定ようやく分かった!最尤推定とベイズ推定
ようやく分かった!最尤推定とベイズ推定
 
充足可能性問題のいろいろ
充足可能性問題のいろいろ充足可能性問題のいろいろ
充足可能性問題のいろいろ
 
比例ハザードモデルはとってもtricky!
比例ハザードモデルはとってもtricky!比例ハザードモデルはとってもtricky!
比例ハザードモデルはとってもtricky!
 
Bayes Independence Test - HSIC と性能を比較する-
Bayes Independence Test - HSIC と性能を比較する-Bayes Independence Test - HSIC と性能を比較する-
Bayes Independence Test - HSIC と性能を比較する-
 
最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向最適輸送の計算アルゴリズムの研究動向
最適輸送の計算アルゴリズムの研究動向
 
[DL輪読会]Deep Direct Reinforcement Learning for Financial Signal Representation...
[DL輪読会]Deep Direct Reinforcement Learning for Financial Signal Representation...[DL輪読会]Deep Direct Reinforcement Learning for Financial Signal Representation...
[DL輪読会]Deep Direct Reinforcement Learning for Financial Signal Representation...
 
強化学習における好奇心
強化学習における好奇心強化学習における好奇心
強化学習における好奇心
 
DQNからRainbowまで 〜深層強化学習の最新動向〜
DQNからRainbowまで 〜深層強化学習の最新動向〜DQNからRainbowまで 〜深層強化学習の最新動向〜
DQNからRainbowまで 〜深層強化学習の最新動向〜
 
人間の意思決定を機械学習でモデル化できるか
人間の意思決定を機械学習でモデル化できるか人間の意思決定を機械学習でモデル化できるか
人間の意思決定を機械学習でモデル化できるか
 
マルコフ連鎖モンテカルロ法入門-1
マルコフ連鎖モンテカルロ法入門-1マルコフ連鎖モンテカルロ法入門-1
マルコフ連鎖モンテカルロ法入門-1
 
2値分類・多クラス分類
2値分類・多クラス分類2値分類・多クラス分類
2値分類・多クラス分類
 
猫でも分かるVariational AutoEncoder
猫でも分かるVariational AutoEncoder猫でも分かるVariational AutoEncoder
猫でも分かるVariational AutoEncoder
 
カルマンフィルタ入門
カルマンフィルタ入門カルマンフィルタ入門
カルマンフィルタ入門
 
セミパラメトリック推論の基礎
セミパラメトリック推論の基礎セミパラメトリック推論の基礎
セミパラメトリック推論の基礎
 
最適輸送の解き方
最適輸送の解き方最適輸送の解き方
最適輸送の解き方
 
Variational AutoEncoder
Variational AutoEncoderVariational AutoEncoder
Variational AutoEncoder
 

Andere mochten auch

Andere mochten auch (10)

Stochastic Variational Inference
Stochastic Variational InferenceStochastic Variational Inference
Stochastic Variational Inference
 
Composing graphical models with neural networks for structured representation...
Composing graphical models with neural networks for structured representation...Composing graphical models with neural networks for structured representation...
Composing graphical models with neural networks for structured representation...
 
第3回nips読み会・関西『variational inference foundations and modern methods』
第3回nips読み会・関西『variational inference  foundations and modern methods』第3回nips読み会・関西『variational inference  foundations and modern methods』
第3回nips読み会・関西『variational inference foundations and modern methods』
 
Fisher Vectorによる画像認識
Fisher Vectorによる画像認識Fisher Vectorによる画像認識
Fisher Vectorによる画像認識
 
safe and efficient off policy reinforcement learning
safe and efficient off policy reinforcement learningsafe and efficient off policy reinforcement learning
safe and efficient off policy reinforcement learning
 
Pred net使ってみた
Pred net使ってみたPred net使ってみた
Pred net使ってみた
 
Prml 4.3
Prml 4.3Prml 4.3
Prml 4.3
 
ICML読み会2016@早稲田
ICML読み会2016@早稲田ICML読み会2016@早稲田
ICML読み会2016@早稲田
 
優れた研究論文の書き方―7つの提案
優れた研究論文の書き方―7つの提案優れた研究論文の書き方―7つの提案
優れた研究論文の書き方―7つの提案
 
increasing the action gap - new operators for reinforcement learning
increasing the action gap - new operators for reinforcement learningincreasing the action gap - new operators for reinforcement learning
increasing the action gap - new operators for reinforcement learning
 

Ähnlich wie Distributed Stochastic Gradient MCMC

070 統計的推測 母集団と推定
070 統計的推測 母集団と推定070 統計的推測 母集団と推定
070 統計的推測 母集団と推定
t2tarumi
 
K070k80 点推定 区間推定
K070k80 点推定 区間推定K070k80 点推定 区間推定
K070k80 点推定 区間推定
t2tarumi
 
K060 中心極限定理clt
K060 中心極限定理cltK060 中心極限定理clt
K060 中心極限定理clt
t2tarumi
 
パターン認識第9章 学習ベクトル量子化
パターン認識第9章 学習ベクトル量子化パターン認識第9章 学習ベクトル量子化
パターン認識第9章 学習ベクトル量子化
Miyoshi Yuya
 
PRML ベイズロジスティック回帰
PRML ベイズロジスティック回帰PRML ベイズロジスティック回帰
PRML ベイズロジスティック回帰
hagino 3000
 

Ähnlich wie Distributed Stochastic Gradient MCMC (20)

逐次モンテカルロ法の基礎
逐次モンテカルロ法の基礎逐次モンテカルロ法の基礎
逐次モンテカルロ法の基礎
 
PRML 8.4-8.4.3
PRML 8.4-8.4.3 PRML 8.4-8.4.3
PRML 8.4-8.4.3
 
量子アニーリングを用いたクラスタ分析
量子アニーリングを用いたクラスタ分析量子アニーリングを用いたクラスタ分析
量子アニーリングを用いたクラスタ分析
 
070 統計的推測 母集団と推定
070 統計的推測 母集団と推定070 統計的推測 母集団と推定
070 統計的推測 母集団と推定
 
K070k80 点推定 区間推定
K070k80 点推定 区間推定K070k80 点推定 区間推定
K070k80 点推定 区間推定
 
ウィナーフィルタと適応フィルタ
ウィナーフィルタと適応フィルタウィナーフィルタと適応フィルタ
ウィナーフィルタと適応フィルタ
 
第8章 ガウス過程回帰による異常検知
第8章 ガウス過程回帰による異常検知第8章 ガウス過程回帰による異常検知
第8章 ガウス過程回帰による異常検知
 
CMSI計算科学技術特論A (2015) 第10回 行列計算における高速アルゴリズム1
CMSI計算科学技術特論A (2015) 第10回 行列計算における高速アルゴリズム1CMSI計算科学技術特論A (2015) 第10回 行列計算における高速アルゴリズム1
CMSI計算科学技術特論A (2015) 第10回 行列計算における高速アルゴリズム1
 
MLaPP 24章 「マルコフ連鎖モンテカルロ法 (MCMC) による推論」
MLaPP 24章 「マルコフ連鎖モンテカルロ法 (MCMC) による推論」MLaPP 24章 「マルコフ連鎖モンテカルロ法 (MCMC) による推論」
MLaPP 24章 「マルコフ連鎖モンテカルロ法 (MCMC) による推論」
 
2013.12.26 prml勉強会 線形回帰モデル3.2~3.4
2013.12.26 prml勉強会 線形回帰モデル3.2~3.42013.12.26 prml勉強会 線形回帰モデル3.2~3.4
2013.12.26 prml勉強会 線形回帰モデル3.2~3.4
 
K060 中心極限定理clt
K060 中心極限定理cltK060 中心極限定理clt
K060 中心極限定理clt
 
パターン認識第9章 学習ベクトル量子化
パターン認識第9章 学習ベクトル量子化パターン認識第9章 学習ベクトル量子化
パターン認識第9章 学習ベクトル量子化
 
Sparse estimation tutorial 2014
Sparse estimation tutorial 2014Sparse estimation tutorial 2014
Sparse estimation tutorial 2014
 
PRML chapter7
PRML chapter7PRML chapter7
PRML chapter7
 
Prml sec6
Prml sec6Prml sec6
Prml sec6
 
2013 03 25
2013 03 252013 03 25
2013 03 25
 
自動微分変分ベイズ法の紹介
自動微分変分ベイズ法の紹介自動微分変分ベイズ法の紹介
自動微分変分ベイズ法の紹介
 
Shunsuke Horii
Shunsuke HoriiShunsuke Horii
Shunsuke Horii
 
PRML ベイズロジスティック回帰
PRML ベイズロジスティック回帰PRML ベイズロジスティック回帰
PRML ベイズロジスティック回帰
 
8.4 グラフィカルモデルによる推論
8.4 グラフィカルモデルによる推論8.4 グラフィカルモデルによる推論
8.4 グラフィカルモデルによる推論
 

Kürzlich hochgeladen

Kürzlich hochgeladen (12)

論文紹介:Selective Structured State-Spaces for Long-Form Video Understanding
論文紹介:Selective Structured State-Spaces for Long-Form Video Understanding論文紹介:Selective Structured State-Spaces for Long-Form Video Understanding
論文紹介:Selective Structured State-Spaces for Long-Form Video Understanding
 
論文紹介:Video-GroundingDINO: Towards Open-Vocabulary Spatio-Temporal Video Groun...
論文紹介:Video-GroundingDINO: Towards Open-Vocabulary Spatio-Temporal Video Groun...論文紹介:Video-GroundingDINO: Towards Open-Vocabulary Spatio-Temporal Video Groun...
論文紹介:Video-GroundingDINO: Towards Open-Vocabulary Spatio-Temporal Video Groun...
 
Amazon SES を勉強してみる その32024/04/26の勉強会で発表されたものです。
Amazon SES を勉強してみる その32024/04/26の勉強会で発表されたものです。Amazon SES を勉強してみる その32024/04/26の勉強会で発表されたものです。
Amazon SES を勉強してみる その32024/04/26の勉強会で発表されたものです。
 
LoRaWANスマート距離検出センサー DS20L カタログ LiDARデバイス
LoRaWANスマート距離検出センサー  DS20L  カタログ  LiDARデバイスLoRaWANスマート距離検出センサー  DS20L  カタログ  LiDARデバイス
LoRaWANスマート距離検出センサー DS20L カタログ LiDARデバイス
 
NewSQLの可用性構成パターン(OCHaCafe Season 8 #4 発表資料)
NewSQLの可用性構成パターン(OCHaCafe Season 8 #4 発表資料)NewSQLの可用性構成パターン(OCHaCafe Season 8 #4 発表資料)
NewSQLの可用性構成パターン(OCHaCafe Season 8 #4 発表資料)
 
LoRaWAN スマート距離検出デバイスDS20L日本語マニュアル
LoRaWAN スマート距離検出デバイスDS20L日本語マニュアルLoRaWAN スマート距離検出デバイスDS20L日本語マニュアル
LoRaWAN スマート距離検出デバイスDS20L日本語マニュアル
 
Utilizing Ballerina for Cloud Native Integrations
Utilizing Ballerina for Cloud Native IntegrationsUtilizing Ballerina for Cloud Native Integrations
Utilizing Ballerina for Cloud Native Integrations
 
Observabilityは従来型の監視と何が違うのか(キンドリルジャパン社内勉強会:2022年10月27日発表)
Observabilityは従来型の監視と何が違うのか(キンドリルジャパン社内勉強会:2022年10月27日発表)Observabilityは従来型の監視と何が違うのか(キンドリルジャパン社内勉強会:2022年10月27日発表)
Observabilityは従来型の監視と何が違うのか(キンドリルジャパン社内勉強会:2022年10月27日発表)
 
Amazon SES を勉強してみる その22024/04/26の勉強会で発表されたものです。
Amazon SES を勉強してみる その22024/04/26の勉強会で発表されたものです。Amazon SES を勉強してみる その22024/04/26の勉強会で発表されたものです。
Amazon SES を勉強してみる その22024/04/26の勉強会で発表されたものです。
 
新人研修 後半 2024/04/26の勉強会で発表されたものです。
新人研修 後半        2024/04/26の勉強会で発表されたものです。新人研修 後半        2024/04/26の勉強会で発表されたものです。
新人研修 後半 2024/04/26の勉強会で発表されたものです。
 
論文紹介: The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games
論文紹介: The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games論文紹介: The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games
論文紹介: The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games
 
知識ゼロの営業マンでもできた!超速で初心者を脱する、悪魔的学習ステップ3選.pptx
知識ゼロの営業マンでもできた!超速で初心者を脱する、悪魔的学習ステップ3選.pptx知識ゼロの営業マンでもできた!超速で初心者を脱する、悪魔的学習ステップ3選.pptx
知識ゼロの営業マンでもできた!超速で初心者を脱する、悪魔的学習ステップ3選.pptx
 

Distributed Stochastic Gradient MCMC

  • 1. Distributed Stochastic Gradient MCMC Sungjin Ahn1 Babak Shahbaba2 Max Welling3 1 Department of Computer Science, University of California, Irvine 2 Department of Statistics, University of California, Irvine 3 Machine Learning Group, University of Amsterdam June 20, 2016 1 / 28
  • 2. 1 おさらい 2 Stochastic Gradient Langevin Dynamics 3 Distributed Inference in LDA 4 Distributed Stochastic Gradient Langevin Dynamics 5 Experiments 2 / 28
  • 3. ベイズ推定 • • 尤度: p(x1:n|θ ) = データ: x1:n = (x1, ∏ x2, ..., xn) n i=1 p(xi|θ) • 事後分布: p(θ |x1:n) ∝ p(x1:n|θ )p(θ ) • 予測分布: ∫ p(x|θ )p(θ |x1:n)dθ 計算コストが高い → 予測分布をサンプリング近似法や変分ベイズで近似する. 3 / 28
  • 5. Abstract • 確率的最適化に基づいた並列化 MCMC を提案した. • stochastic gradient MCMC を並列化する際にの問題に対処法 を示した. • 提案手法を LDA に用いて Wikipedia と Pubmed の大規模コー パスを学習させたところ,通常の並列化 MCMC での学習時間を 27 時間→ 30 分に軽減できた.(perplexity の収束時間) 5 / 28
  • 6. Mini-batch-based MCMC Stochastic Gradient Langevin Dynamics (SGLD) [Welling & Teh, 2011] • 単純な Stochastic gradient ascent を拡張する. • MAP ではなく完全な事後分布からサンプリングするようなベ イズ推定のアルゴリズムを考える. • Langevin Dynamics [Neal, 2010] は MCMC の手法の一つ.解 が単なる MAP に陥らないよう,パラメータの更新時にGaussian noise を加える. 6 / 28
  • 7. Mini-batch-based MCMC 準備 • データセット X = {x1, x2, ..., xN },パラメータ θ ∈ d • モデルの同時分布 p(X , θ ) ∝ p(X |θ )p(θ ) • 目的:事後分布 p(θ |X ) からのサンプルを得る. 7 / 28
  • 8. Stochastic Gradient Ascent 確率的最適化アルゴリズム [Robbins & Monro, 1951] At iterattion t = 1,2,... : • データセットから subset {xt1,..., xtn} (n << N) をとる. • subset を用いて対数事後分布の勾配の近似値を計算する. ∇log p(θt|X) ≈ ∇log p(θt) + n N n∑ i=1 ∇log p(xti|θt) • この値を用いてパラメータの値を更新する. θt+1 = θt + εt 2 ∇log p(θt) + n N n∑ i=1 ∇log p(xti|θt) 8 / 28
  • 9. Stochastic Gradient Ascent 収束するためのステップサイズ εt の主な条件は次である. ∞∑ t=1 εt = ∞, ∞∑ t=1 ε2 t < ∞ • パラメータ空間を幅広く行ったり来たりさせるために,ス テップサイズを小さくしすぎない. • 局所最適解 (MAP) に収束させるために,ステップサイズを 0にま で減少させない. 9 / 28
  • 10. Langevin Dynamics 事後分布 p(θ |X ) に収束するダイナミクスを描く確率微分方程式: ∆θ(t) = 1 2 ∇log p(θ(t)|X) + ∆b(t) ここで,b(t) はブラウン運動である. • 勾配項は確率が高い場所でより多くの時間を費やすようにダ イナミクスを促進する. • ダイナミクスがパラメータ空間全体を探索するようにブラウ ン運動をノイズとして与える. 10 / 28
  • 11. Langevin Dynamics オイラーの有限差分法で離散化すると, θt+1 = θt + ε 2 ∇log p(θt|X) + ηt ηt ∼ N(0,ε) • ノイズの総計は勾配のステップサイズの平均 • 有限のステップサイズの場合は離散化誤差が発生する. • 離散化はメトロポリス・ヘイスティング法(MH 法)のaccept/ reject step により修正される. • ε → 0 のとき,acceptance rate は 1 になる. 11 / 28
  • 12. Stochastic Gradient Langevin Dynamics subset Xt = {xt1,..., xtn} (n << N) p(θt|Xt) ∝ p(θt) ∏n i=1 p(xti|θt) サンプル点は次の式から生成される. θt+1 = θt + ε 2 ∇log p(θt) + N n n∑ i=1 ∇log p(xti|θt) Stochastic Gradient Acsent + ηt noise • Gaussian noise:ηt ∼ N(0,εt) • Annealed step-size: ∑∞ t=1 εt = ∞, ∑∞ t=1 ε2 t < ∞ • MH 法の accept/reject step は計算コストが高いが, acceptance probability が1になるので無視できる. 12 / 28
  • 13. Distributed Inference in LDA Approximate Distributed LDA (AD-LDA) [Newman et al, 2007] • MCMC にかかる計算時間を減らすため,それぞれの local shard に周辺化ギブスサンプリングを行う手法. • N 1回のサンプリングごとの計算コストが ( S ) まで減少. • global states との同期により local copy の重みを修正できる. 13 / 28
  • 14. Distributed Inference in LDA AD-LDA の欠点: • データセットのサイズが大きいと,worker を追加しても遅い. • global states との同期のせいで block-by-the-slowest に苦し む. block-by-the-slowest:最も遅い worker のタスク完了を他の worker が待機している状態. • 並列化した連鎖の実行に大きな overhead(並列化計算のための 処理) がかかる. 14 / 28
  • 15. Distributed Inference in LDA Yahoo-LDA (Y-LDA) [Ahmed et al, 2012] • 非同期での更新により,block-by-the-slowest の解決をした. • 非同期で無限に更新するとパフォーマンスが悪化する. [Ho et al, 2013] 15 / 28
  • 16. Distributed SGLD Distributed Stochastic Gradient Langevin Dynamics (D- SGLD): • SGLD を並列計算し,大規模データに対しても高速なサンプ リングを可能にしたい. • local shards からランダムにミニバッチをサンプリングするSGLD algorithm を提案する. 16 / 28
  • 17. SGLD on Partitioned Datasets 準備 • sudataset X = {x1,..., xN } を S 個の bset(shard) に分割: X1,..., XS, X = ∪sXs, N = ∑ s Ns • データ x が与えられた時の対数尤度 (score function): g(θ; x) = ∇θ log p(θ; x) • X からサンプリングされた n 個のデータ点のミニバッチ:X n shard Xs からサンプリングされたとき:Xs n イ テレーション t で X n s がサンプリングされたとき:X n s,t • score function の合計:G(θ ; X ) = ∑ x∈X g(θ; x) score function の平均:g¯(θ ; X ) = |X 1 | G(θ; X) 17 / 28
  • 18. SGLD on Partitioned Datasets Proposition shard s = 1,...,S: • shard size: Ns(Ns > 0, ∑ s Ns = N) • 正規化された shard の選択頻度: qs(qs ∈ (0, 1), ∑ s qs = 1) このとき以下の推定値は SGLD の推定値として妥当である. ¯gd(θ; X n s ) de f = Ns Nqs ¯g(θ; X n s ) ここで,shard s は,scheduler h( ) からサンプリングされる.ただ し,頻度 = {q1,...,qS}. 証明は省略(supplementary material があるらしい) 18 / 28
  • 19. SGLD on Partitioned Datasets 流れ (1) shard をサンプリングで選ぶ. s ∼ h( ) = Category(q1,...,qS) (2) 選んだ shard からミニバッチ X n s をサンプリングする. (3) ミニバッチを使って score 平均 g¯(θ ; Xs n ) 計算 . (4) score 平均に N N q s s をかけて,重みを修正する. 19 / 28
  • 20. SGLD on Partitioned Datasets SGLD update rule θt+1 ← θt + εt 2 ∇log p(θt) + Nst qst ¯g(θt; X n st ) + νt • ¯g(θt; X n st ) の項は step size の補正になっている. • このアルゴリズムは相対的にサイズが大きい,または他より使用 されていない shard に対して,大きな step をとる.(全ての data-case が連鎖の混合に等しく用いられている) 20 / 28
  • 21. Traveling Worker Parallel Chains 問題点 • short-communication-cycle problem: 連鎖はイテレーションごとに新しい worker に遷移する必要があ るため,伝達サイクルが短い. • block-by-the-slowest problem: worker の偏りによる反応遅れのせいで,次にスケジューリングさ れている worker が処理待ちになる. 21 / 28
  • 22. Distributed Trajectory Sampling short-communication-cycle problem への対処 → trajectory sampling で軽減できる. • 他の worker に移る代わりに,ひとつの worker を訪れるごとに 連鎖 c で τ 回連続して更新する. • τ 更新後.最後 (τ 番目) の状態から次の連鎖の worker に移る • communication-cycle が (n) → (τn) に増加. • communication overhead は減少. 22 / 28
  • 23. Adaptive Load Balancing block-by-the-slowest problem への対処 → 遅い worker のタスクが終了するまで,速い worker を長めに働 かせる. • worker 全体の反応時間が可能な限り平均化される. 23 / 28
  • 25. Dataset データセットは以下の2つを使用した. • Wikipedia corpus: 4.6M articles of approximately 811M tokens in total. • PubMed Abstract corpus: 8.2M articles of approximately 730M tokens in total. 25 / 28
  • 26. 比較手法 • AD-LDA • Async-LDA (Y-LDA) [Ahmed et al., 2012; Smola & Narayanamurthy, 2010] • SGRLD: Stochastic gradient Riemannian Langevin dynamics [Patterson & Teh, 2013)] 26 / 28
  • 28. Conclution 本論文では D-SGLD を紹介し,以下のことを示した. • 適切な修正項を加えることで,提案アルゴリズムは local subset の偏りを防いだ. • trajectory sampling により,communication overhead を減少 させた. • 分散低減法により,収束スピードを早めた. 28 / 28