SlideShare ist ein Scribd-Unternehmen logo
1 von 14
Adversarial Training
to avoid overfitting
NBME top#2 Solution and Discussion
https://www.kaggle.com/competitions/nbme-score-
clinical-patient-notes/discussion/323085
Feedback top#1で活用されたが,NBMEでは
“Although its CV score was quite higher than the one I selected above, both its public LB score and private LB score were lower.
It seems that my way of doing pseudo labeling was better. It may be that being quite new to these techniques I didn't tune
them correctly. I will try them in future competitions for sure.”
とあるので汎用性についてはさらなる実装と議論が必要か。
Adversarial Training
Inputs Perturbation
into “Local” worst-case
Weights Perturbation
into “Global” worst-case
Gradient-based Adversary Not gradient-based
Need Labels Not need Labels
FGM, SiFT
VAT, TRADES,
SMART
MART
AWP
Adversarial Training
https://arxiv.org/abs/1412.6572 : Goodfellow IJ et al., 2015, ICLR 2015
 摂動を加えた入力の中でモデルにとってhigh confidenceに間違えるようなもの = Adversarial Examples
 Adversarial Examplesを作成しながらモデル精度を高める = Adversarial Training
 ランダムノイズを加える点では一般的なaugmentationと捉えられるが,その中でもよりadversarialなもの
Adversarial Example
自信をもって間違えている
Perturbation is:
𝜂 = 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦
loss backwardの時点で入力xについても自動微分が得られる
lossを大きくする方向 (𝑠𝑖𝑔𝑛)に微小(𝜖)動かす
参考) https://ai-scholar.tech/articles/adversarial-perturbation/Earlystopping
目的関数に追加: 𝐽 𝜃, 𝑥, 𝑦 = 𝛼𝐽 𝜃, 𝑥, 𝑦 + 1 − 𝛼 𝐽 𝜃, 𝑥 + 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 , 𝑦
VAT; Virtual Adversarial Training
https://arxiv.org/pdf/1507.00677.pdf : Miyato T et al., 2016, ICLR 2016
https://arxiv.org/pdf/1704.03976v2.pdf : Miyato T et al., 2018
https://arxiv.org/pdf/1605.07725.pdf : Miyato T et al., 2021
 モデル分散の平滑化を目的とした正則化として働く
 当初のAdversarial Examples作成にはLabel (𝑦)情報が必要 (𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝒚 なので)だが,VATではPerturbationを与えた
ときにmodel outputがどれくらい動くか(LDS)を用いるのでLabel不要
 勾配情報を用いるとsignによってadversarial directionが決まるが,VATではLDSを探索することでdirectionを決めるた
め,’virtual‘ adversarial trainingという名前がついた
𝑥 𝑛
𝑖𝑛𝑝𝑢𝑡 𝑠𝑝𝑎𝑐𝑒
𝑝 𝑦|𝑥 𝑛
, 𝜃
model output
(e.g. classification
prediction)
なめらかな予測曲線
= 過学習しづらい (正しい学習)
過学習の傾向
観測されたデータ
0.9
ちょっとノイズが入ると予測が
大きく変動する
ちょっと入力を動かしたときの予測自体 (精度ではない)が変動
するかをKL Div.で定量したものをLocal Distributional
Smoothing (LDS)とした。
∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ≡ 𝐾𝐿 𝑝 𝑦|𝑥 𝑛 , 𝜃 ||𝑝 𝑦|𝑥 𝑛 + 𝑟, 𝜃
𝑟𝑣−𝑎𝑑𝑣
𝑛
≡ 𝑎𝑟𝑔 max
𝑟
∆𝐾𝐿 𝑟, 𝑥 𝑛
, 𝜃 ; 𝑟 2 ≤ 𝜖
𝐿𝐷𝑆 𝑥 𝑛
, 𝜃 ≡ −∆𝐾𝐿 𝑟𝑣−𝑎𝑑𝑣
𝑛
, 𝑥 𝑛
, 𝜃
https://github.com/tensorflow/models/tree/master/research/adversarial_text
LDSを正則化項として目的関数に追加
詰まるところ, 𝑟𝑣−𝑎𝑑𝑣
𝑛
を決めるのが大変
𝑟𝑣−𝑎𝑑𝑣
𝑛
≡ 𝑎𝑟𝑔 max
𝑟
∆𝐾𝐿 𝑟, 𝑥 𝑛
, 𝜃 ; 𝑟 2 ≤ 𝜖
これ自身も学習で求める
𝑖𝑛𝑝𝑢𝑡 に対して𝑟𝑎𝑛𝑑𝑜𝑚 𝑣𝑒𝑐𝑡𝑜𝑟 𝑑を初期化して以下SGDによって更新
𝑑 ← 𝛻𝑟𝐾𝐿 𝑟, 𝑥, 𝜃
𝑟=𝜉𝑑
𝑤ℎ𝑒𝑟𝑒 𝑣 =
𝑣
𝑣 2
⋯
⋯
普通にloss求める
LDS求める
⋯
一旦勾配計算とめて
⋯
⋯
adversarial_losses.py
train_classifier.py
KL Div.に対するdの勾配を求めて
𝑟𝑣−𝑎𝑑𝑣を得る
https://github.com/tensorflow/models/tree/master/research/adversarial_text
VAT-Pytorchのイメージ
mt_dnn/perturbation.py
TRADES
https://arxiv.org/abs/1901.08573 : Zhang H et al., 2019, ICML 2019
𝜌𝑇𝑅𝐴𝐷𝐸𝑆 𝑤 =
1
𝑛
𝑖=1
𝑛
𝐶𝐸 𝑓𝑤 𝑥𝑖 , 𝑦𝑖 + 𝛽 𝑚𝑎𝑥𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖
′
 正直VATとの違いがあまり分からない
 コードがとても使いやすい
https://github.com/yaodongyu/TRADES
from trades import trades_loss
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# calculate robust loss - TRADES loss
loss = trades_loss(model=model,
x_natural=data,
y=target,
optimizer=optimizer,
step_size=args.step_size,
epsilon=args.epsilon,
perturb_steps=args.num_steps,
beta=args.beta,
distance='l_inf')
loss.backward()
optimizer.step()
MART
https://openreview.net/forum?id=rklOg6EFwS : Wang Y et al., 2019, ICLR 2020
https://github.com/YisenWang/MART
𝜌𝑀𝐴𝑅𝑇 𝑤 =
1
𝑛
𝑖=1
𝑛
𝐵𝐶𝐸 𝑓𝑤 𝑥𝑖
′
, 𝑦𝑖 + 𝜆 𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖
′
∙ 1 − 𝑓𝑤 𝑥𝑖 𝑦𝑖
𝑤ℎ𝑒𝑟𝑒 𝑓𝑤 𝑥𝑖 𝑦𝑖
𝑑𝑒𝑛𝑜𝑡𝑒𝑠 𝑡ℎ𝑒 𝑦𝑖𝑡ℎ 𝑒𝑙𝑒𝑚𝑒𝑛𝑡 𝑜𝑓 𝑜𝑢𝑡𝑝𝑢𝑡 𝑣𝑒𝑐𝑡𝑜𝑟 𝑓𝑤 𝑥𝑖
𝑎𝑛𝑑 𝑥𝑖
′
𝑖𝑠 𝑓𝑟𝑜𝑚 𝑎𝑟𝑔 max
𝑥𝑖
′∈ℬ𝜖 𝑥𝑖
𝐶𝐸 𝑓𝑤 𝑥𝑖
′
, 𝑦𝑖
 正解例を当てることに焦点をあて,Adversarial Loss は負例(部)に関して足される
 adversarial examplesの生成には教師データが必要
from mart import mart_loss
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# calculate robust loss - MART loss
loss = mart_loss(model=model,
x_natural=data,
y=target,
optimizer=optimizer,
step_size=args.step_size,
epsilon=args.epsilon,
perturb_steps=args.num_steps,
beta=args.beta,
distance='l_inf')
loss.backward()
optimizer.step()
SMART
https://arxiv.org/abs/1911.03437 : Jiang H et al., 2021
1.正則化項の追加と2.Optimizationの工夫によって構成される
1. Smoothness-Inducing Adversarial Regularization: VATと同じ
2. Bregman Proximal Point Optimization: 学習パラメータ𝜃を更新前から大きく離れないよう更新
𝜃𝑡+1 = 𝑎𝑟𝑔 min
𝜃
ℒ 𝜃 + 𝜆𝑠𝑅𝑠 𝜃 + 𝜇𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡
≒VAT
𝜆𝑆, 𝜇: ℎ𝑦𝑝𝑒𝑟𝑝𝑎𝑟𝑎𝑚𝑒𝑡𝑒𝑟
𝑤ℎ𝑒𝑟𝑒 𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 =
1
𝑛 𝑖=1
𝑛
∆𝐾𝐿 𝑓 𝑥𝑖; 𝜃 , 𝑓 𝑥𝑖; 𝜃𝑡 𝑓 𝑥; 𝜃 は入力xに対するoutput
https://github.com/namisan/mt-dnn
たぶんBregman Proximal Point Optimizationについてはgithubコードに実装されていない
AWP; Adversarial Weight Perturbation
https://arxiv.org/abs/2004.05884 : Wu D et al., 2020
 double-perturbation mechanism: both inputs and weights are adversarially perturbed
 weightの重みに摂動を加えた場合のモデル精度の不安定性(weight loss landscape)の低さが重要であると主張
⇒ 一般化に成功
𝑤𝑒𝑖𝑔ℎ𝑡 𝑙𝑜𝑠𝑠 𝑙𝑎𝑛𝑑𝑠𝑐𝑎𝑝𝑒
𝑔 𝛼 = 𝜌 𝑤 + 𝛼𝑑 =
1
𝑛
𝑖=1
𝑛
max
𝑥′𝑖−𝑥𝑖 𝑝≤𝜖
ℓ 𝑓𝑤+𝛼𝑑 𝑥𝑖
′
, 𝑦𝑖
𝑤ℎ𝑒𝑟𝑒 𝑑 𝑖𝑠 𝑠𝑎𝑚𝑝𝑙𝑒𝑑 𝑓𝑟𝑜𝑚 𝑎 𝐺𝑎𝑢𝑠𝑠𝑖𝑎𝑛 𝑑𝑖𝑠𝑡𝑟𝑖𝑏𝑢𝑡𝑖𝑜𝑛 𝑎𝑛𝑑 𝑓𝑖𝑙𝑡𝑒𝑟 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑏𝑦 𝑑𝑙,𝑗 ←
𝑑𝑙,𝑗
𝑑𝑙,𝑗 𝐹
𝑤𝑙,𝑗 𝐹
重みの摂動に対して安定
過学習の状態では重みの摂動に
対して不安定
https://github.com/csdongxian/AWP
gap
が小さいほど
Test Accuracy
は高い傾向
⇒ gapをLossに追加
 なぜweight perturbationが有効かの考察
• adversarial perturbation on inputsはそれぞれの入力についてモデルが不得意とするperturbationを与える
= “local” worst-case
• adversarial perturbation on weightsは全データに関して予測を(程よく)崩すようなperturbationを与える
= “global” worst-case
⇒ ともに助け合いながらRobust modelが学習される
min
𝑤
𝜌 𝑤 + 𝜌 𝑤 + 𝑣 − 𝜌 𝑤 → min
𝑤
𝜌 𝑤 + 𝑣 ただし𝜌 𝑤 は入力データに対するadversarial loss
より
min
𝑤
max
𝑣∈𝑉
1
𝑛
𝑖=1
𝑛
max
𝑥𝑖
′−𝑥𝑖 𝑝
≤𝜖
ℓ 𝑓𝑤+𝑣 𝑥𝑖
′
, 𝑦𝑖
このmaximizeは各batchについて計算されるので注意
batch-sizeは重要。
AWPは結果として大きさに関する
正則化としても機能している
AWP Code
https://github.com/namisan/mt-dnn では,at_AWPやtrades_AWPコードが公開されているので任意のモデルに応用できるはず
for batch_idx, (data, target) in enumerate(train_loader):
x_natural, target = data.to(device), target.to(device)
# craft adversarial examples
x_adv = perturb_input(model=model,
x_natural=x_natural,
step_size=step_size,
epsilon=epsilon,
perturb_steps=args.num_steps,
distance=args.norm)
model.train()
# calculate adversarial weight perturbation
if epoch >= args.awp_warmup:
awp = awp_adversary.calc_awp(inputs_adv=x_adv,
inputs_clean=x_natural,
targets=target,
beta=args.beta)
awp_adversary.perturb(awp)
optimizer.zero_grad()
logits_adv = model(x_adv)
loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1),
F.softmax(model(x_natural), dim=1),
reduction='batchmean')
# calculate natural loss and backprop
logits = model(x_natural)
loss_natural = F.cross_entropy(logits, target)
loss = loss_natural + args.beta * loss_robust
inputsに対するadversarial attack
weightsに対するadversarial attack
AWP Code
NBME top#1 Code
https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook
 正直参考にした論文と結構異なるので混乱…
 inputに対するadversarial trainingはなし (たぶんpre-trainedだからだと思う…)
def attack_backward(self, x, y, attention_mask,epoch):
if (self.adv_lr == 0) or (epoch < self.start_epoch):
return None
self._save()
for i in range(self.adv_step):
self._attack_step()
with torch.cuda.amp.autocast():
adv_loss, tr_logits = self.model(input_ids=x, attention_mask=attention_mask, labels=y)
adv_loss = adv_loss.mean()
self.optimizer.zero_grad()
self.scaler.scale(adv_loss).backward()
self._restore()
def _attack_step(self):
e = 1e-6
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None and self.adv_param in name:
norm1 = torch.norm(param.grad)
norm2 = torch.norm(param.data.detach())
if norm1 != 0 and not torch.isnan(norm1):
r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
param.data.add_(r_at)
param.data = torch.min(
torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1]
)
# param.data.clamp_(*self.backup_eps[name])
# Define AWP class in advance
awp = AWP(model,
optimizer,
adv_lr=args.adv_lr,
adv_eps=args.adv_eps,
start_epoch=args.num_train_st
eps/args.epochs,
scaler=scaler)
# during train....
# logits = model(inputs)
# loss = ....
# loss.backward()
awp.attack_backward(input_ids, labels,
attention_mask, step)
# optimizer.step()
𝜌 𝑤 + 𝑣 の𝑣が
𝑣 = 𝛻𝑤ℒ という感じ??
FGM; Fast Gradient Method
https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/143764
 一番最初のAdversarial trainingのこと
 inputsに対するadversarial attackだが,NLPの場合embeddingに対してかかるのでweightsに対するadversarial attackの
ように記述する
 書き方からして先ほどのAWPはこれを真似たのだろう class FGM():
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0:
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
fgm = FGM(model)
for batch_input, batch_label in data:
loss = model(batch_input, batch_label)
loss.backward()
# adversarial training
fgm.attack()
loss_adv = model(batch_input, batch_label)
loss_adv.backward()
fgm.restore()
optimizer.step()
model.zero_grad()
SiFT; Scale Invariant Fine-Tuning
https://github.com/microsoft/DeBERTa/tree/master/DeBERTa/sift
 FGMと同じ。embeddingについてGradient-base adversarial attackを行う

Weitere ähnliche Inhalte

Was ist angesagt?

backbone としての timm 入門
backbone としての timm 入門backbone としての timm 入門
backbone としての timm 入門Takuji Tahara
 
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative ModelsDeep Learning JP
 
Transformer メタサーベイ
Transformer メタサーベイTransformer メタサーベイ
Transformer メタサーベイcvpaper. challenge
 
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...Deep Learning JP
 
敵対的生成ネットワーク(GAN)
敵対的生成ネットワーク(GAN)敵対的生成ネットワーク(GAN)
敵対的生成ネットワーク(GAN)cvpaper. challenge
 
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...Deep Learning JP
 
【DL輪読会】ViT + Self Supervised Learningまとめ
【DL輪読会】ViT + Self Supervised Learningまとめ【DL輪読会】ViT + Self Supervised Learningまとめ
【DL輪読会】ViT + Self Supervised LearningまとめDeep Learning JP
 
グラフニューラルネットワーク入門
グラフニューラルネットワーク入門グラフニューラルネットワーク入門
グラフニューラルネットワーク入門ryosuke-kojima
 
PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門tmtm otm
 
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜SSII
 
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked AutoencodersDeep Learning JP
 
【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features
【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features
【論文読み会】Deep Clustering for Unsupervised Learning of Visual FeaturesARISE analytics
 
[DL輪読会]GQNと関連研究,世界モデルとの関係について
[DL輪読会]GQNと関連研究,世界モデルとの関係について[DL輪読会]GQNと関連研究,世界モデルとの関係について
[DL輪読会]GQNと関連研究,世界モデルとの関係についてDeep Learning JP
 
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~SSII
 
近年のHierarchical Vision Transformer
近年のHierarchical Vision Transformer近年のHierarchical Vision Transformer
近年のHierarchical Vision TransformerYusuke Uchida
 
ドメイン適応の原理と応用
ドメイン適応の原理と応用ドメイン適応の原理と応用
ドメイン適応の原理と応用Yoshitaka Ushiku
 
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...Deep Learning JP
 
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...Deep Learning JP
 
NLPにおけるAttention~Seq2Seq から BERTまで~
NLPにおけるAttention~Seq2Seq から BERTまで~NLPにおけるAttention~Seq2Seq から BERTまで~
NLPにおけるAttention~Seq2Seq から BERTまで~Takuya Ono
 

Was ist angesagt? (20)

backbone としての timm 入門
backbone としての timm 入門backbone としての timm 入門
backbone としての timm 入門
 
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models
[DL輪読会]Transframer: Arbitrary Frame Prediction with Generative Models
 
Transformer メタサーベイ
Transformer メタサーベイTransformer メタサーベイ
Transformer メタサーベイ
 
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...[DL輪読会]Revisiting Deep Learning Models for Tabular Data  (NeurIPS 2021) 表形式デー...
[DL輪読会]Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) 表形式デー...
 
敵対的生成ネットワーク(GAN)
敵対的生成ネットワーク(GAN)敵対的生成ネットワーク(GAN)
敵対的生成ネットワーク(GAN)
 
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...[DL輪読会]data2vec: A General Framework for  Self-supervised Learning in Speech,...
[DL輪読会]data2vec: A General Framework for Self-supervised Learning in Speech,...
 
【DL輪読会】ViT + Self Supervised Learningまとめ
【DL輪読会】ViT + Self Supervised Learningまとめ【DL輪読会】ViT + Self Supervised Learningまとめ
【DL輪読会】ViT + Self Supervised Learningまとめ
 
グラフニューラルネットワーク入門
グラフニューラルネットワーク入門グラフニューラルネットワーク入門
グラフニューラルネットワーク入門
 
PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門PRML学習者から入る深層生成モデル入門
PRML学習者から入る深層生成モデル入門
 
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜
SSII2022 [TS1] Transformerの最前線〜 畳込みニューラルネットワークの先へ 〜
 
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
【DL輪読会】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
 
【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features
【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features
【論文読み会】Deep Clustering for Unsupervised Learning of Visual Features
 
Semantic segmentation
Semantic segmentationSemantic segmentation
Semantic segmentation
 
[DL輪読会]GQNと関連研究,世界モデルとの関係について
[DL輪読会]GQNと関連研究,世界モデルとの関係について[DL輪読会]GQNと関連研究,世界モデルとの関係について
[DL輪読会]GQNと関連研究,世界モデルとの関係について
 
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~
SSII2019OS: 深層学習にかかる時間を短くしてみませんか? ~分散学習の勧め~
 
近年のHierarchical Vision Transformer
近年のHierarchical Vision Transformer近年のHierarchical Vision Transformer
近年のHierarchical Vision Transformer
 
ドメイン適応の原理と応用
ドメイン適応の原理と応用ドメイン適応の原理と応用
ドメイン適応の原理と応用
 
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...
[DL輪読会]Set Transformer: A Framework for Attention-based Permutation-Invariant...
 
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...
[DL輪読会]Wasserstein GAN/Towards Principled Methods for Training Generative Adv...
 
NLPにおけるAttention~Seq2Seq から BERTまで~
NLPにおけるAttention~Seq2Seq から BERTまで~NLPにおけるAttention~Seq2Seq から BERTまで~
NLPにおけるAttention~Seq2Seq から BERTまで~
 

Ähnlich wie adversarial training.pptx

Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版Daiyu Hatakeyama
 
20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LTKohei KaiGai
 
Wandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdfWandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdfYuya Yamamoto
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 JubatusハンズオンJubatusOfficial
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 JubatusハンズオンYuya Unno
 
20170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#820170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#8Kohei KaiGai
 
Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会Naoki Takaesu
 
Learning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryLearning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryAkira Takahashi
 
Asakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for HadoopAsakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for HadoopTakashi Kambayashi
 
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説Daiyu Hatakeyama
 
新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案yohhoy
 
Java ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsugJava ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsugMasatoshi Tada
 
Try_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hackTry_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hackkimukou_26 Kimukou
 
プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集tecopark
 
プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集tecopark
 
エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎Daiyu Hatakeyama
 
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用de:code 2017
 
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdfウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdfYuya Yamamoto
 

Ähnlich wie adversarial training.pptx (20)

ADVENTURE_Solidの概要
ADVENTURE_Solidの概要ADVENTURE_Solidの概要
ADVENTURE_Solidの概要
 
Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版Azure Machine Learning Services 概要 - 2019年3月版
Azure Machine Learning Services 概要 - 2019年3月版
 
20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT20181212 - PGconf.ASIA - LT
20181212 - PGconf.ASIA - LT
 
Wandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdfWandb Monthly Meetup August 2023.pdf
Wandb Monthly Meetup August 2023.pdf
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 Jubatusハンズオン
 
第1回 Jubatusハンズオン
第1回 Jubatusハンズオン第1回 Jubatusハンズオン
第1回 Jubatusハンズオン
 
20170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#820170127 JAWS HPC-UG#8
20170127 JAWS HPC-UG#8
 
Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会Okinawa.rb 第2回勉強会
Okinawa.rb 第2回勉強会
 
Learning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.GeomtryLearning Template Library Design using Boost.Geomtry
Learning Template Library Design using Boost.Geomtry
 
Asakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for HadoopAsakusa Enterprise Batch Processing Framework for Hadoop
Asakusa Enterprise Batch Processing Framework for Hadoop
 
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
Microsoft Open Tech Night: Azure Machine Learning - AutoML徹底解説
 
新しい並列for構文のご提案
新しい並列for構文のご提案新しい並列for構文のご提案
新しい並列for構文のご提案
 
Java ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsugJava ORマッパー選定のポイント #jsug
Java ORマッパー選定のポイント #jsug
 
Try_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hackTry_to_writecode_practicaltest #atest_hack
Try_to_writecode_practicaltest #atest_hack
 
プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集プログラミングで言いたいこと聞きたいこと集
プログラミングで言いたいこと聞きたいこと集
 
プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集プログラミングで言いたい聞きたいこと集
プログラミングで言いたい聞きたいこと集
 
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
CMSI計算科学技術特論A (2015) 第3回 OpenMPの基礎
 
エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎エンジニアのための機械学習の基礎
エンジニアのための機械学習の基礎
 
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
[AI08] 深層学習フレームワーク Chainer × Microsoft で広がる応用
 
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdfウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
ウェビナー:Nejumiリーダーボードを使った自社LLMモデルの独自評価.pdf
 

adversarial training.pptx

  • 1. Adversarial Training to avoid overfitting NBME top#2 Solution and Discussion https://www.kaggle.com/competitions/nbme-score- clinical-patient-notes/discussion/323085 Feedback top#1で活用されたが,NBMEでは “Although its CV score was quite higher than the one I selected above, both its public LB score and private LB score were lower. It seems that my way of doing pseudo labeling was better. It may be that being quite new to these techniques I didn't tune them correctly. I will try them in future competitions for sure.” とあるので汎用性についてはさらなる実装と議論が必要か。
  • 2. Adversarial Training Inputs Perturbation into “Local” worst-case Weights Perturbation into “Global” worst-case Gradient-based Adversary Not gradient-based Need Labels Not need Labels FGM, SiFT VAT, TRADES, SMART MART AWP
  • 3. Adversarial Training https://arxiv.org/abs/1412.6572 : Goodfellow IJ et al., 2015, ICLR 2015  摂動を加えた入力の中でモデルにとってhigh confidenceに間違えるようなもの = Adversarial Examples  Adversarial Examplesを作成しながらモデル精度を高める = Adversarial Training  ランダムノイズを加える点では一般的なaugmentationと捉えられるが,その中でもよりadversarialなもの Adversarial Example 自信をもって間違えている Perturbation is: 𝜂 = 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 loss backwardの時点で入力xについても自動微分が得られる lossを大きくする方向 (𝑠𝑖𝑔𝑛)に微小(𝜖)動かす 参考) https://ai-scholar.tech/articles/adversarial-perturbation/Earlystopping 目的関数に追加: 𝐽 𝜃, 𝑥, 𝑦 = 𝛼𝐽 𝜃, 𝑥, 𝑦 + 1 − 𝛼 𝐽 𝜃, 𝑥 + 𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝑦 , 𝑦
  • 4. VAT; Virtual Adversarial Training https://arxiv.org/pdf/1507.00677.pdf : Miyato T et al., 2016, ICLR 2016 https://arxiv.org/pdf/1704.03976v2.pdf : Miyato T et al., 2018 https://arxiv.org/pdf/1605.07725.pdf : Miyato T et al., 2021  モデル分散の平滑化を目的とした正則化として働く  当初のAdversarial Examples作成にはLabel (𝑦)情報が必要 (𝜖 𝑠𝑖𝑔𝑛 𝛻𝑥𝐽 𝜃, 𝑥, 𝒚 なので)だが,VATではPerturbationを与えた ときにmodel outputがどれくらい動くか(LDS)を用いるのでLabel不要  勾配情報を用いるとsignによってadversarial directionが決まるが,VATではLDSを探索することでdirectionを決めるた め,’virtual‘ adversarial trainingという名前がついた 𝑥 𝑛 𝑖𝑛𝑝𝑢𝑡 𝑠𝑝𝑎𝑐𝑒 𝑝 𝑦|𝑥 𝑛 , 𝜃 model output (e.g. classification prediction) なめらかな予測曲線 = 過学習しづらい (正しい学習) 過学習の傾向 観測されたデータ 0.9 ちょっとノイズが入ると予測が 大きく変動する ちょっと入力を動かしたときの予測自体 (精度ではない)が変動 するかをKL Div.で定量したものをLocal Distributional Smoothing (LDS)とした。 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ≡ 𝐾𝐿 𝑝 𝑦|𝑥 𝑛 , 𝜃 ||𝑝 𝑦|𝑥 𝑛 + 𝑟, 𝜃 𝑟𝑣−𝑎𝑑𝑣 𝑛 ≡ 𝑎𝑟𝑔 max 𝑟 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ; 𝑟 2 ≤ 𝜖 𝐿𝐷𝑆 𝑥 𝑛 , 𝜃 ≡ −∆𝐾𝐿 𝑟𝑣−𝑎𝑑𝑣 𝑛 , 𝑥 𝑛 , 𝜃 https://github.com/tensorflow/models/tree/master/research/adversarial_text LDSを正則化項として目的関数に追加
  • 5. 詰まるところ, 𝑟𝑣−𝑎𝑑𝑣 𝑛 を決めるのが大変 𝑟𝑣−𝑎𝑑𝑣 𝑛 ≡ 𝑎𝑟𝑔 max 𝑟 ∆𝐾𝐿 𝑟, 𝑥 𝑛 , 𝜃 ; 𝑟 2 ≤ 𝜖 これ自身も学習で求める 𝑖𝑛𝑝𝑢𝑡 に対して𝑟𝑎𝑛𝑑𝑜𝑚 𝑣𝑒𝑐𝑡𝑜𝑟 𝑑を初期化して以下SGDによって更新 𝑑 ← 𝛻𝑟𝐾𝐿 𝑟, 𝑥, 𝜃 𝑟=𝜉𝑑 𝑤ℎ𝑒𝑟𝑒 𝑣 = 𝑣 𝑣 2 ⋯ ⋯ 普通にloss求める LDS求める ⋯ 一旦勾配計算とめて ⋯ ⋯ adversarial_losses.py train_classifier.py KL Div.に対するdの勾配を求めて 𝑟𝑣−𝑎𝑑𝑣を得る https://github.com/tensorflow/models/tree/master/research/adversarial_text
  • 7. TRADES https://arxiv.org/abs/1901.08573 : Zhang H et al., 2019, ICML 2019 𝜌𝑇𝑅𝐴𝐷𝐸𝑆 𝑤 = 1 𝑛 𝑖=1 𝑛 𝐶𝐸 𝑓𝑤 𝑥𝑖 , 𝑦𝑖 + 𝛽 𝑚𝑎𝑥𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖 ′  正直VATとの違いがあまり分からない  コードがとても使いやすい https://github.com/yaodongyu/TRADES from trades import trades_loss def train(args, model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() # calculate robust loss - TRADES loss loss = trades_loss(model=model, x_natural=data, y=target, optimizer=optimizer, step_size=args.step_size, epsilon=args.epsilon, perturb_steps=args.num_steps, beta=args.beta, distance='l_inf') loss.backward() optimizer.step()
  • 8. MART https://openreview.net/forum?id=rklOg6EFwS : Wang Y et al., 2019, ICLR 2020 https://github.com/YisenWang/MART 𝜌𝑀𝐴𝑅𝑇 𝑤 = 1 𝑛 𝑖=1 𝑛 𝐵𝐶𝐸 𝑓𝑤 𝑥𝑖 ′ , 𝑦𝑖 + 𝜆 𝐾𝐿 𝑓𝑤 𝑥𝑖 ||𝑓𝑤 𝑥𝑖 ′ ∙ 1 − 𝑓𝑤 𝑥𝑖 𝑦𝑖 𝑤ℎ𝑒𝑟𝑒 𝑓𝑤 𝑥𝑖 𝑦𝑖 𝑑𝑒𝑛𝑜𝑡𝑒𝑠 𝑡ℎ𝑒 𝑦𝑖𝑡ℎ 𝑒𝑙𝑒𝑚𝑒𝑛𝑡 𝑜𝑓 𝑜𝑢𝑡𝑝𝑢𝑡 𝑣𝑒𝑐𝑡𝑜𝑟 𝑓𝑤 𝑥𝑖 𝑎𝑛𝑑 𝑥𝑖 ′ 𝑖𝑠 𝑓𝑟𝑜𝑚 𝑎𝑟𝑔 max 𝑥𝑖 ′∈ℬ𝜖 𝑥𝑖 𝐶𝐸 𝑓𝑤 𝑥𝑖 ′ , 𝑦𝑖  正解例を当てることに焦点をあて,Adversarial Loss は負例(部)に関して足される  adversarial examplesの生成には教師データが必要 from mart import mart_loss def train(args, model, device, train_loader, optimizer, epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() # calculate robust loss - MART loss loss = mart_loss(model=model, x_natural=data, y=target, optimizer=optimizer, step_size=args.step_size, epsilon=args.epsilon, perturb_steps=args.num_steps, beta=args.beta, distance='l_inf') loss.backward() optimizer.step()
  • 9. SMART https://arxiv.org/abs/1911.03437 : Jiang H et al., 2021 1.正則化項の追加と2.Optimizationの工夫によって構成される 1. Smoothness-Inducing Adversarial Regularization: VATと同じ 2. Bregman Proximal Point Optimization: 学習パラメータ𝜃を更新前から大きく離れないよう更新 𝜃𝑡+1 = 𝑎𝑟𝑔 min 𝜃 ℒ 𝜃 + 𝜆𝑠𝑅𝑠 𝜃 + 𝜇𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 ≒VAT 𝜆𝑆, 𝜇: ℎ𝑦𝑝𝑒𝑟𝑝𝑎𝑟𝑎𝑚𝑒𝑡𝑒𝑟 𝑤ℎ𝑒𝑟𝑒 𝐷𝐵𝑟𝑒𝑔 𝜃, 𝜃𝑡 = 1 𝑛 𝑖=1 𝑛 ∆𝐾𝐿 𝑓 𝑥𝑖; 𝜃 , 𝑓 𝑥𝑖; 𝜃𝑡 𝑓 𝑥; 𝜃 は入力xに対するoutput https://github.com/namisan/mt-dnn たぶんBregman Proximal Point Optimizationについてはgithubコードに実装されていない
  • 10. AWP; Adversarial Weight Perturbation https://arxiv.org/abs/2004.05884 : Wu D et al., 2020  double-perturbation mechanism: both inputs and weights are adversarially perturbed  weightの重みに摂動を加えた場合のモデル精度の不安定性(weight loss landscape)の低さが重要であると主張 ⇒ 一般化に成功 𝑤𝑒𝑖𝑔ℎ𝑡 𝑙𝑜𝑠𝑠 𝑙𝑎𝑛𝑑𝑠𝑐𝑎𝑝𝑒 𝑔 𝛼 = 𝜌 𝑤 + 𝛼𝑑 = 1 𝑛 𝑖=1 𝑛 max 𝑥′𝑖−𝑥𝑖 𝑝≤𝜖 ℓ 𝑓𝑤+𝛼𝑑 𝑥𝑖 ′ , 𝑦𝑖 𝑤ℎ𝑒𝑟𝑒 𝑑 𝑖𝑠 𝑠𝑎𝑚𝑝𝑙𝑒𝑑 𝑓𝑟𝑜𝑚 𝑎 𝐺𝑎𝑢𝑠𝑠𝑖𝑎𝑛 𝑑𝑖𝑠𝑡𝑟𝑖𝑏𝑢𝑡𝑖𝑜𝑛 𝑎𝑛𝑑 𝑓𝑖𝑙𝑡𝑒𝑟 𝑛𝑜𝑟𝑚𝑎𝑙𝑖𝑧𝑒𝑑 𝑏𝑦 𝑑𝑙,𝑗 ← 𝑑𝑙,𝑗 𝑑𝑙,𝑗 𝐹 𝑤𝑙,𝑗 𝐹 重みの摂動に対して安定 過学習の状態では重みの摂動に 対して不安定 https://github.com/csdongxian/AWP gap が小さいほど Test Accuracy は高い傾向 ⇒ gapをLossに追加
  • 11.  なぜweight perturbationが有効かの考察 • adversarial perturbation on inputsはそれぞれの入力についてモデルが不得意とするperturbationを与える = “local” worst-case • adversarial perturbation on weightsは全データに関して予測を(程よく)崩すようなperturbationを与える = “global” worst-case ⇒ ともに助け合いながらRobust modelが学習される min 𝑤 𝜌 𝑤 + 𝜌 𝑤 + 𝑣 − 𝜌 𝑤 → min 𝑤 𝜌 𝑤 + 𝑣 ただし𝜌 𝑤 は入力データに対するadversarial loss より min 𝑤 max 𝑣∈𝑉 1 𝑛 𝑖=1 𝑛 max 𝑥𝑖 ′−𝑥𝑖 𝑝 ≤𝜖 ℓ 𝑓𝑤+𝑣 𝑥𝑖 ′ , 𝑦𝑖 このmaximizeは各batchについて計算されるので注意 batch-sizeは重要。 AWPは結果として大きさに関する 正則化としても機能している
  • 12. AWP Code https://github.com/namisan/mt-dnn では,at_AWPやtrades_AWPコードが公開されているので任意のモデルに応用できるはず for batch_idx, (data, target) in enumerate(train_loader): x_natural, target = data.to(device), target.to(device) # craft adversarial examples x_adv = perturb_input(model=model, x_natural=x_natural, step_size=step_size, epsilon=epsilon, perturb_steps=args.num_steps, distance=args.norm) model.train() # calculate adversarial weight perturbation if epoch >= args.awp_warmup: awp = awp_adversary.calc_awp(inputs_adv=x_adv, inputs_clean=x_natural, targets=target, beta=args.beta) awp_adversary.perturb(awp) optimizer.zero_grad() logits_adv = model(x_adv) loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1), F.softmax(model(x_natural), dim=1), reduction='batchmean') # calculate natural loss and backprop logits = model(x_natural) loss_natural = F.cross_entropy(logits, target) loss = loss_natural + args.beta * loss_robust inputsに対するadversarial attack weightsに対するadversarial attack
  • 13. AWP Code NBME top#1 Code https://www.kaggle.com/code/wht1996/feedback-nn-train/notebook  正直参考にした論文と結構異なるので混乱…  inputに対するadversarial trainingはなし (たぶんpre-trainedだからだと思う…) def attack_backward(self, x, y, attention_mask,epoch): if (self.adv_lr == 0) or (epoch < self.start_epoch): return None self._save() for i in range(self.adv_step): self._attack_step() with torch.cuda.amp.autocast(): adv_loss, tr_logits = self.model(input_ids=x, attention_mask=attention_mask, labels=y) adv_loss = adv_loss.mean() self.optimizer.zero_grad() self.scaler.scale(adv_loss).backward() self._restore() def _attack_step(self): e = 1e-6 for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None and self.adv_param in name: norm1 = torch.norm(param.grad) norm2 = torch.norm(param.data.detach()) if norm1 != 0 and not torch.isnan(norm1): r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e) param.data.add_(r_at) param.data = torch.min( torch.max(param.data, self.backup_eps[name][0]), self.backup_eps[name][1] ) # param.data.clamp_(*self.backup_eps[name]) # Define AWP class in advance awp = AWP(model, optimizer, adv_lr=args.adv_lr, adv_eps=args.adv_eps, start_epoch=args.num_train_st eps/args.epochs, scaler=scaler) # during train.... # logits = model(inputs) # loss = .... # loss.backward() awp.attack_backward(input_ids, labels, attention_mask, step) # optimizer.step() 𝜌 𝑤 + 𝑣 の𝑣が 𝑣 = 𝛻𝑤ℒ という感じ??
  • 14. FGM; Fast Gradient Method https://www.kaggle.com/c/tweet-sentiment-extraction/discussion/143764  一番最初のAdversarial trainingのこと  inputsに対するadversarial attackだが,NLPの場合embeddingに対してかかるのでweightsに対するadversarial attackの ように記述する  書き方からして先ほどのAWPはこれを真似たのだろう class FGM(): def __init__(self, model): self.model = model self.backup = {} def attack(self, epsilon=1., emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: self.backup[name] = param.data.clone() norm = torch.norm(param.grad) if norm != 0: r_at = epsilon * param.grad / norm param.data.add_(r_at) def restore(self, emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if param.requires_grad and emb_name in name: assert name in self.backup param.data = self.backup[name] self.backup = {} fgm = FGM(model) for batch_input, batch_label in data: loss = model(batch_input, batch_label) loss.backward() # adversarial training fgm.attack() loss_adv = model(batch_input, batch_label) loss_adv.backward() fgm.restore() optimizer.step() model.zero_grad() SiFT; Scale Invariant Fine-Tuning https://github.com/microsoft/DeBERTa/tree/master/DeBERTa/sift  FGMと同じ。embeddingについてGradient-base adversarial attackを行う