A. J. Ratner, H. R. Ehrenberg, et al., “Learning to Compose Domain-Specific Transformations for Data Augmentation”, NIPS2017
(https://papers.nips.cc/paper/6916-learning-to-compose-domain-specific-transformations-for-data-augmentation)
Slides for NIPS2017 paper reading meet-up @Tokyo
https://abeja-innovation-meetup.connpass.com/event/75189/
Learning to Compose Domain-Specific Transformations for Data Augmentation
1. Learning to Compose Domain-Specific Transformations for Data Augmentation
Tatsuya Shirakawa
tatsuya@abeja.asia
2. ABEJA, Inc. (Researcher)
- Deep Learning
- Computer Vision
- Natural Language Processing
- Graph Convolution / Graph Embedding
- Mathematical Optimization
- https://github.com/TatsuyaShiraka
tech blog → http://tech-blog.abeja.asia/
Poincaré Embeddings Graph Convolution
We are hiring! → https://www.abeja.asia/recruit/
→ https://six.abejainc.com/
3. A. J. Ratner, H. R. Ehrenberg, et al., “Learning to Compose Domain-Specific Transformations for Data Augmentation”, NIPS2017
Today’s Paper
Problem to solve
• Learn how to compose predefined data transformations (TFs) so that the transformed data stay natural (data augmentation)
How to solve
• Formulate the problem as a sequence generation problem
• Learn the sequence generator with a policy gradient method
6. Applying a sequence of transformation functions (TFs) to each data point to augment the dataset
Data Augmentation (DA)
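A minimal sketch of what "applying a sequence of TFs" can look like. The toy image format (a 2D list of intensities) and the three TFs `hflip`, `vflip` and `brighten` are illustrative assumptions, not the TFs used in the paper; `random_sequence` picks TFs uniformly at random, i.e. the heuristic baseline that a learned generator would replace.

```python
import random

def hflip(img):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in img]

def vflip(img):
    """Reverse the row order (top-to-bottom mirror)."""
    return img[::-1]

def brighten(img, delta=10, max_value=255):
    """Add a constant to every pixel, clipped at max_value."""
    return [[min(p + delta, max_value) for p in row] for row in img]

# Hypothetical TF registry; real TFs would be domain-specific.
TFS = {"hflip": hflip, "vflip": vflip, "brighten": brighten}

def apply_sequence(img, names):
    """Apply the named TFs one after another."""
    for name in names:
        img = TFS[name](img)
    return img

def random_sequence(length, rng):
    """Pick each TF uniformly at random (no learning involved)."""
    return [rng.choice(sorted(TFS)) for _ in range(length)]

rng = random.Random(0)
x = [[0, 50], [100, 250]]          # toy 2x2 "image"
seq = random_sequence(3, rng)      # a length-3 TF sequence
augmented = apply_sequence(x, seq)
fixed = apply_sequence(x, ["hflip", "brighten"])
```

Every sequence has the same fixed length here, which matches the fixed-length setting used in the paper.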
7. Common Assumption
Transformed data are natural, and essential information (e.g. classes) is kept unchanged
… But massive DA can easily break this assumption
DA can break information
(CIFAR-10)
8. • Generator generates sequences of TFs
• Discriminator judges whether transformed data are realistic or not
• End model (learned afterward)
This Paper — Learning to Compose TFs
[Figure: generator G, discriminator D, end model Df]
Technical Remarks: all transformation sequences have the same length L
10. • Discriminator outputs whether the given data are realistic (1) or not (0)
• Relaxed Assumption
TFs either preserve essential information or collapse it
Discriminator
11. Generator G is learned adversarially against D
This leads G to generate transformation sequences that don’t collapse data
Generative Adversarial Objective
Technical Remarks: Generator is not conditioned on data
12. Generator should not learn null transformation sequences, so a diversity term is also maximized
Examples of null transformation sequences
• Horizontal flip × 2
• Rotate left 5° and then rotate right 5°
Diversity Objective
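A small sketch of why null sequences need to be discouraged. The distance used here (mean absolute pixel difference) is an assumption chosen for illustration; the slide does not say which distance the paper uses. A sequence that flips twice lands back on the original, so its distance to the original is zero, while a single flip gives a positive value.

```python
def hflip(img):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in img]

def apply_sequence(img, tfs):
    """Apply a list of TF callables in order."""
    for tf in tfs:
        img = tf(img)
    return img

def distance(a, b):
    """Mean absolute pixel difference (hypothetical choice of distance)."""
    diffs = [abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb)]
    return sum(diffs) / len(diffs)

x = [[1, 2, 3], [4, 5, 6]]

null_score = distance(apply_sequence(x, [hflip, hflip]), x)  # flip twice
flip_score = distance(apply_sequence(x, [hflip]), x)         # flip once
```

Maximizing such a distance alongside the adversarial term pushes the generator away from sequences that cancel themselves out.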
14. • We can optimize the discriminator and the generator alternately
• The discriminator can be optimized by a simple gradient ascent method
• The generator cannot be optimized by simple gradient descent, because its objective runs through a discrete sequence generation process
Optimization
15. Reformulate the optimization problem for G as a sequential decision making (RL) problem
Optimization of G — RL problem
[Figure: x → h_{τ_1} → x̃_1 → h_{τ_2} → x̃_2 → … → h_{τ_L} → x̃_L, with rewards r_1, r_2, …, r_L]
Technical Remarks: loss is defined as loss(x) = log(1-D(x)) in the paper
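\[ r_t = \mathrm{loss}(\tilde{x}_t) - \mathrm{loss}(\tilde{x}_{t-1}), \qquad \sum_{t=1}^{L} r_t = \mathrm{loss}(\tilde{x}_L) - \mathrm{loss}(x) \qquad (\tilde{x}_0 = x) \]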
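The per-step rewards telescope: summing them leaves only the loss of the final transformed point minus the loss of the original. The sketch below checks this numerically. The scalar "data", the three toy TFs and `toy_discriminator` are assumptions made up for the example; only loss(x) = log(1 - D(x)) comes from the remark above.

```python
import math

def toy_discriminator(x):
    """Hypothetical D: probability that x is realistic, falling as |x| grows."""
    return 1.0 / (1.0 + math.exp(abs(x) - 3.0))

def loss(x):
    """loss(x) = log(1 - D(x))."""
    return math.log(1.0 - toy_discriminator(x))

def incremental_rewards(x, tfs):
    """Return ([r_1..r_L], x_L) with r_t = loss(x_t) - loss(x_{t-1}), x_0 = x."""
    rewards = []
    prev = x
    for tf in tfs:
        cur = tf(prev)
        rewards.append(loss(cur) - loss(prev))
        prev = cur
    return rewards, prev

# Toy TFs acting on a scalar instead of an image.
tfs = [lambda v: v + 1.0, lambda v: v * 2.0, lambda v: v - 0.5]
x = 0.5
rewards, x_final = incremental_rewards(x, tfs)
total = sum(rewards)
expected = loss(x_final) - loss(x)
```

Because loss(x) does not depend on the generator, minimizing the summed rewards is the same as minimizing the loss of the final transformed point.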
16. The final loss U(θ) can be minimized by a policy gradient method
Optimization of G — Policy Gradient
π: stochastic transition policy implicitly defined by G
Policy Gradient Method
1. Generate samples (run the policy)
2. Estimate the return
3. Improve the policy: $\theta \leftarrow \theta - \eta \nabla_\theta U(\theta)$
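The three steps can be sketched with a REINFORCE-style update on a toy problem. Everything concrete here is an assumption for illustration: scalar data, a fixed hand-written `discriminator`, three toy TFs (the last one pushes data far out of the realistic range), and a policy that samples each TF independently from one softmax over logits `theta`. The mean-return baseline is a common variance-reduction choice, not something stated on the slide.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy TFs on scalars; index 2 "collapses" the data (hypothetical).
TFS = [lambda v: v + 0.1, lambda v: -v, lambda v: v + 10.0]

def discriminator(x):
    """Fixed toy D: realistic data live near zero."""
    return 1.0 / (1.0 + math.exp(abs(x) - 3.0))

def loss(x):
    """loss(x) = log(1 - D(x)); low when D thinks x is realistic."""
    return math.log(1.0 - discriminator(x))

def train(steps=300, batch=32, length=2, lr=0.1, seed=0):
    rng = random.Random(seed)
    theta = [0.0] * len(TFS)
    for _ in range(steps):
        probs = softmax(theta)
        # 1. Generate samples (run the policy).
        samples = []
        for _ in range(batch):
            x = rng.uniform(-1.0, 1.0)
            seq = rng.choices(range(len(TFS)), weights=probs, k=length)
            for k in seq:
                x = TFS[k](x)
            # 2. Estimate the return: loss of the final transformed point.
            samples.append((seq, loss(x)))
        baseline = sum(ret for _, ret in samples) / batch
        # 3. Improve the policy: theta <- theta - lr * grad U(theta),
        #    using grad log softmax = onehot(k) - probs for each sampled TF.
        grad = [0.0] * len(TFS)
        for seq, ret in samples:
            for k in seq:
                for j in range(len(TFS)):
                    indicator = 1.0 if j == k else 0.0
                    grad[j] += (ret - baseline) * (indicator - probs[j]) / batch
        theta = [t - lr * g for t, g in zip(theta, grad)]
    return softmax(theta)

final_probs = train()
```

In this toy setting the probability of the collapsing TF shrinks during training, since sequences that contain it end up with a higher loss. In the actual method the discriminator is trained alternately with the generator rather than held fixed.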
17. Independent Model — Mean Field Model
learns the task-specific “accuracy” and “frequency” of each TF
State-based Model — LSTM
some combinations of TFs can be very lossy (e.g. blur -> zoom, brighten -> saturation), so each TF is chosen conditioned on the preceding ones
Generator (Policy) Model
18. • D measures whether data are realistic or not
• G (mean field / LSTM) generates sequences of TFs of length L
• Adversarial training for G & D
• Standard gradient ascent method for D
• Policy gradient method for G
Summary of Proposed Method
22. • MNIST
• CIFAR-10
Datasets — ACE corpus
• ACE corpus
• Mammography Tumor-Classification Dataset (DDSM)
The goal is to identify mentions of employer-employee relations in news articles
Conditional word swap TF
1. Construct a trigram language model
2. Sample a word conditioned on the preceding words
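The two steps of the conditional word swap can be sketched as follows. This is a minimal count-based trigram model under assumptions: the three-sentence corpus, the `<s>` padding token and the helper names `build_trigram_model` and `swap_word` are all made up for the example, and the paper's TF may apply further constraints on which words are swapped.

```python
import random
from collections import Counter, defaultdict

def build_trigram_model(sentences):
    """Step 1: map each (w1, w2) context to counts of the words that followed it."""
    model = defaultdict(Counter)
    for sentence in sentences:
        tokens = ["<s>", "<s>"] + sentence.split()
        for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:]):
            model[(w1, w2)][w3] += 1
    return model

def swap_word(tokens, position, model, rng):
    """Step 2: resample tokens[position] conditioned on the two preceding words."""
    padded = ["<s>", "<s>"] + tokens
    context = (padded[position], padded[position + 1])
    candidates = model.get(context)
    if not candidates:
        return list(tokens)  # unseen context: leave the sentence unchanged
    words = sorted(candidates)
    weights = [candidates[w] for w in words]
    new_tokens = list(tokens)
    new_tokens[position] = rng.choices(words, weights=weights, k=1)[0]
    return new_tokens

# Hypothetical toy corpus.
corpus = [
    "Alice works for Acme",
    "Bob works for Initech",
    "Carol works at Acme",
]
model = build_trigram_model(corpus)
rng = random.Random(0)

original = "Alice works for Acme".split()
swapped = swap_word(original, 3, model, rng)
untouched = swap_word(["Dave", "likes", "tea"], 2, model, rng)
```

Sampling from the language model keeps the swapped word plausible in its context, which is the point of making the swap conditional rather than uniform.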
23. • MNIST
• CIFAR-10
Datasets — DDSM dataset
• ACE corpus
• Mammography Tumor-Classification Dataset (DDSM)
Standard image TFs
Subselected so as not to break class-invariance
Segmentation-based TFs
1. Segment the tumor mass
2. Perform TFs (e.g. rotation or shifting)
3. Stitch it into a randomly-sampled benign tissue image
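A toy sketch of the segment / transform / stitch steps. The threshold-based `segment` stands in for a real tumor segmentation and is purely an assumption, as are the 2x3 toy images and the horizontal `shift` used as the TF.

```python
import random

def segment(img, threshold):
    """Step 1 (toy): mark pixels brighter than threshold as the mass."""
    return [[p > threshold for p in row] for row in img]

def shift(grid, dx, fill):
    """Step 2: shift a 2D grid dx columns to the right, padding with fill."""
    width = len(grid[0])
    out = []
    for row in grid:
        new_row = [fill] * width
        for c, value in enumerate(row):
            if 0 <= c + dx < width:
                new_row[c + dx] = value
        out.append(new_row)
    return out

def stitch(mass_img, mask, background):
    """Step 3: paste the masked pixels onto the background image."""
    return [
        [m if keep else b for m, keep, b in zip(mrow, krow, brow)]
        for mrow, krow, brow in zip(mass_img, mask, background)
    ]

def segmentation_tf(img, benign_pool, dx, rng, threshold=128):
    mask = segment(img, threshold)
    moved_img = shift(img, dx, fill=0)
    moved_mask = shift(mask, dx, fill=False)
    background = rng.choice(benign_pool)  # randomly sampled benign image
    return stitch(moved_img, moved_mask, background)

tumor = [[10, 200, 10],
         [10, 220, 10]]
benign_pool = [[[30, 30, 30],
                [30, 30, 30]]]
out = segmentation_tf(tumor, benign_pool, dx=1, rng=random.Random(0))
```

Only the segmented mass is moved, so the surrounding tissue comes entirely from the benign image.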
24. Results — CIFAR-10 Classification
Basic: random crop
Heur.: random composition of TFs
+ DS: also allows domain-specific (semantic-segmentation-based) TFs
26. Results — Training Progress on MNIST
https://hazyresearch.github.io/snorkel/blog/tanda.html
27. • Adversarial Training for Data Augmentation
• Optimization with standard gradient ascent (for D) and the policy gradient method (for G)
• Achieved better performance on several datasets
Summary