Diabetic retinopathy (DR) is the leading cause of blindness among people aged 20 to 64 and afflicts more than 120 million people worldwide. Fortunately, vigilant monitoring greatly improves the chance of preserving one's eyesight. This work used deep learning to analyze fundus images of the retina for automated diagnosis of DR on a grading scale from 0 (normal) to 4 (severe). We achieved a substantial improvement in accuracy over traditional approaches, and advanced further by using a small auxiliary dataset that provided low-effort, high-value supervision. The training and test data, provided by the 2015 Kaggle competition, comprised over 80,000 high-resolution images (>4 megapixels) and required Amazon EC2 scalability to provide the GPU hardware needed to train a convolutional network with over 2 million parameters. For the competition, we focused on accurately modeling the scoring system, penalizing bad mistakes more severely, and combating the over-prevalence of grade-0 examples in the dataset. We explored ideas first at low resolution on low-cost single-GPU instances; after finding the best methodology, we showed it scaled to equivalent improvements at high resolution, using the more expensive quad-GPU instances more effectively. This prototype model placed 15th out of 650 teams worldwide with a kappa score of 0.78. We have since advanced the model via a new architecture that integrates the prototype with a network specialized in finding dot hemorrhages, which are critical to identifying early DR. By annotating a small set of 200 images for hemorrhages, performance jumped to a kappa of 0.82. We believe strategies that employ a bit more supervision for more effective learning are pivotal for addressing deep learning's greatest weakness: its voracious appetite for data.
2. Problem, Data and Motivation
Motivation:
- Affects ~100M people, many in the developed world; ~45% of diabetics
- Make the process faster, assist ophthalmologists, enable self-help
- Widespread disease: enable early diagnosis and care
Task: given a fundus image, rate the severity of diabetic retinopathy
- 5 classes: 0 (normal), 1, 2, 3, 4 (severe)
- Hard classification (though it may be solved as ordinal regression)
- Metric: quadratic weighted kappa, with a (pred − real)² penalty
Data: Kaggle (California Healthcare Foundation, EyePACS)
- ~35,000 training images, ~54,000 test images
- High resolution: variable, more than 2560 x 1920
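The quadratic weighted kappa metric above can be sketched in plain Python — a minimal version assuming integer grades 0–4 (the competition used the standard QWK definition):

```python
from collections import Counter

def quadratic_weighted_kappa(y_true, y_pred, n_classes=5):
    """1 - (observed quadratic disagreement / chance-expected disagreement)."""
    n = len(y_true)
    observed = Counter(zip(y_true, y_pred))
    hist_true = Counter(y_true)
    hist_pred = Counter(y_pred)
    num = sum((i - j) ** 2 * observed[(i, j)]
              for i in range(n_classes) for j in range(n_classes))
    den = sum((i - j) ** 2 * hist_true[i] * hist_pred[j] / n
              for i in range(n_classes) for j in range(n_classes))
    return 1.0 - num / den
```

Perfect agreement gives kappa = 1, and larger grade gaps are penalized quadratically — which is why a 0-vs-4 confusion hurts far more than a 2-vs-3 confusion.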
6. Challenges
- High resolution images: atypical in vision; GPU batch-size issues; discriminative features are small
- Grading criteria: not clearly specified (EyePACS guidelines); must be learned from data
- Incorrect labeling; artifacts in ~40% of images
- Optimizing the approach to QWK
- Severe class imbalance: class 0 dominates
- Too few training examples

Image size    Batch size
224 x 224     128
2K x 2K       2
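The batch-size table follows from simple memory arithmetic; a rough sketch, assuming float32 RGB inputs and counting only the raw input tensor (activations and weights scale similarly):

```python
def input_batch_bytes(batch, height, width, channels=3, dtype_bytes=4):
    # raw memory for one float32 RGB input batch
    return batch * channels * height * width * dtype_bytes

# 128 images at 224 x 224 occupy roughly the same memory as
# just 2 images at 2K x 2K:
print(input_batch_bytes(128, 224, 224))   # 77,070,336 bytes (~77 MB)
print(input_batch_bytes(2, 2048, 2048))   # 100,663,296 bytes (~100 MB)
```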
10. Challenges: Incorrect Labeling
- Mentioned in the problem statement
- Confirmed with doctors
12. Challenges: Optimizing the Approach to QWK
- Hard classification is non-differentiable
- Backpropagation is difficult
[Figure: quadratic penalty/loss vs. predicted class (0-4), plotted for each true class]
17. Challenges: Optimizing the Approach to QWK
- Squared-error approximation: a differentiable surrogate
[Figure: squared-error penalty/loss as a continuous function of the prediction (e.g. 2.5), for each true class 0-4]
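One differentiable surrogate consistent with the squared-error approximation above — a sketch, not necessarily the exact loss used — takes the expected grade under the predicted class distribution and applies a squared-error penalty:

```python
def expected_grade(probs):
    # expected value of the grade under the predicted class distribution
    return sum(grade * p for grade, p in enumerate(probs))

def surrogate_loss(probs, true_grade):
    # squared-error penalty on the continuous expected grade --
    # differentiable, unlike the argmax of hard classification
    return (expected_grade(probs) - true_grade) ** 2
```

Because `expected_grade` is a smooth function of the class probabilities, gradients flow through it during backpropagation, while the penalty still grows quadratically with the grade gap, matching QWK.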
18. Challenges: Severe Class Imbalance
- Naïve training degenerates to a 3-class problem, or predicts all zeros!
- Learn all classes separately: one-vs-all?
- Balance classes while training — but what about at test time?
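Balanced (even) sampling during training can be sketched as drawing the class uniformly first, then an example from that class — a hypothetical helper, not the competition code:

```python
import random

def balanced_indices(labels, k, seed=0):
    # draw k training indices, picking the class uniformly at random first
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    classes = sorted(by_class)
    return [rng.choice(by_class[rng.choice(classes)]) for _ in range(k)]

# Even with 90% class-0 labels, roughly half the drawn samples are class 1.
```

The open question on the slide remains: the model then sees a class prior at training time that differs from the true test-time distribution.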
19. Challenges: Too Few Training Examples
- Big learning models need more data!
- Harness the test set?
20. Conventional Approaches
Literature survey:
- Hand-designed features to pick out each component
- Clean images, small datasets
- Optic disk and exudate segmentation: fails due to artifacts
- SVMs: poor performance
23. Step 1: Pre-processing
- Registration: Hough circles; remove the portion outside the circle
- Downsize to a common size (224 x 224 or 1K x 1K)
- Color correction
- Normalization (mean, variance)
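The mean/variance normalization step can be sketched per color channel — a minimal standard-library version:

```python
import statistics

def normalize_channel(values):
    # zero-mean, unit-variance scaling of one color channel's pixel values
    mu = statistics.fmean(values)
    sigma = statistics.pstdev(values) or 1.0  # guard against flat channels
    return [(v - mu) / sigma for v in values]
```

In practice this runs over each channel of each downsized image so that pixel intensities are comparable across the widely varying cameras and lighting conditions in the dataset.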
24. Step 2: CNNs
Architecture: input image → 3 conv layers (depth 96) → MaxPool (stride 2) → 3 conv layers (depth 256) → MaxPool (stride 2) → 3 conv layers (depth 384) → MaxPool (stride 2) → 3 conv layers (depth 1024) → MaxPool (stride 2) → AvgPool → class probabilities
- Network in Network architecture, 7.5M parameters
- No FC layers; spatial average pooling instead
- Transfer learning (ImageNet)
- Variable learning rates: low for the "ImageNet" layers, with a schedule
- Combat lack of data and over-fitting: dropout, early stopping, data augmentation (flips, rotation)
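The later reduction from 7.5M to 2.2M parameters (replacing the depth-1024 block with depths 384, 64, 5) follows from conv-layer parameter counting; the 3 x 3 kernel size here is an illustrative assumption:

```python
def conv_params(c_in, c_out, k):
    # weights (c_in * c_out * k * k) plus one bias per output channel
    return c_in * c_out * k * k + c_out

# A single 3x3 conv from 384 to 1024 channels already costs ~3.5M parameters,
# while 384 -> 64 costs a small fraction of that:
print(conv_params(384, 1024, 3))  # 3,539,968
print(conv_params(384, 64, 3))    # 221,248
```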
26. Step 2: CNNs
Architecture: input image → 3 conv layers (depth 96) → MaxPool (stride 2) → 3 conv layers (depth 256) → MaxPool (stride 2) → 3 conv layers (depth 384) → MaxPool (stride 2) → 3 conv layers (depths 384, 64, 5) → MaxPool (stride 2) → AvgPool → class probabilities
- Network in Network architecture, 2.2M parameters
- No FC layers; spatial average pooling instead
- Transfer learning (ImageNet)
- Variable learning rates: low for the "ImageNet" layers, with a schedule
- Combat lack of data and over-fitting: dropout, early stopping, data augmentation (flips, rotation)
30. Step 2: CNN Experiments
- What image size to use? Strategize at 224 x 224, then extend to 1024 x 1024
- What loss function? Mean squared error (MSE), negative log-likelihood (NLL), or a linear combination (annealed)
- Class imbalance: even sampling → true sampling
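The linear combination (annealing) of losses can be sketched as a schedule that shifts weight from NLL to MSE over training — the linear form here is an assumption, since the slide only names the idea:

```python
def annealed_loss(nll, mse, step, anneal_steps):
    # weight moves linearly from pure NLL (step 0) to pure MSE (step >= anneal_steps)
    w = min(step / anneal_steps, 1.0)
    return (1.0 - w) * nll + w * mse
```

The intuition: NLL gives a well-behaved training signal early on, while MSE better matches the quadratic QWK penalty once the network has learned the basics.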
31-39. Step 2: CNN Experiments
Image size: 224 x 224

With pre-trained layers frozen (no learning):
Loss function            Result
MSE                      Fails to learn
MSE                      Fails to learn
NLL                      Kappa < 0.1
NLL (top layers only)    Kappa = 0.29

With pre-trained layers fine-tuned at 0.01x step size:
Loss function            Result
NLL                      Kappa = 0.42
NLL                      Kappa = 0.51
MSE                      Kappa = 0.56
42. Computing Setup
- Amazon EC2: GPU nodes, VPC, Amazon EBS-optimized instances
- Single-GPU nodes for 224 x 224 (g2.2xlarge); multi-GPU nodes for 1K x 1K (g2.8xlarge)
- Storage: Amazon EBS, Amazon S3
- Python for processing; Torch library (Lua) for training
44. Computing Setup
[Diagram: a single GPU node on EC2 runs a model experiment; Data 1 and Data 2 live on EBS (gp2) volumes restored from an S3 snapshot]
45. Computing Setup
[Diagram: a VPC on EC2; an EBS-optimized central node serves Data 1 and Data 2 from EBS (gp2) volumes (S3 snapshot) to a master GPU node training Models 1-10]
46. Computing Setup
[Diagram: the same VPC setup, annotated with ~200 MB/s throughput from the central node's EBS volumes to the GPU node]
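The ~200 MB/s figure translates directly into data-loading time per pass over the training set; the per-image size used below is an illustrative assumption:

```python
def streaming_seconds(n_images, mb_per_image, mb_per_s=200.0):
    # time to stream the whole training set once at the measured throughput
    return n_images * mb_per_image / mb_per_s

# e.g. ~35,000 training images at an assumed ~4 MB each:
print(streaming_seconds(35000, 4.0))  # 700.0 seconds of raw I/O per epoch
```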
47. Computing Setup
[Diagram: scaling out within the VPC: Master 1 trains Models 1-10 and Master 2 trains Models 11-20, both EBS-optimized, sharing the central data node (EBS gp2 volumes, S3 snapshot)]
57. Train Lesion Detector
- Only hemorrhages so far
- Positives: 1,866 patches extracted from 216 images/subjects
- Negatives: ~25k class-0 images
- Pre-processing/augmentation: crop a random 256 x 256 patch from the input; flips
- Pre-trained Network in Network architecture
- Accuracy: 99% on negatives, 76% on positives
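The random-crop-plus-flip augmentation above can be sketched as follows (hypothetical helper names; patches are lists of pixel rows):

```python
import random

def random_crop_box(height, width, size=256, rng=random):
    # top-left and bottom-right corners of a random size x size crop
    # that fits inside an height x width image
    top = rng.randrange(height - size + 1)
    left = rng.randrange(width - size + 1)
    return top, left, top + size, left + size

def maybe_flip(patch, rng=random):
    # horizontal flip with probability 0.5
    return [row[::-1] for row in patch] if rng.random() < 0.5 else patch
```

Random crops let a small set of annotated hemorrhage patches yield many distinct training examples, which matters given only 216 annotated source images.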
63. Hybrid Architecture
[Diagram: the input is split into 64 tiles of 256 x 256. The main network produces 64 x 31 x 31 feature maps and the lesion detector 2 x 31 x 31 maps; concatenated to 66 x 31 x 31, then passed through 2 conv layers and fused (2048, 1024) into class probabilities]
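The fuse step combining the main network's and lesion detector's feature maps is, per the slide's dimensions (64 + 2 → 66 channels at 31 x 31), channel-wise concatenation; a minimal sketch with feature maps as nested lists:

```python
def fuse_channels(main_maps, lesion_maps):
    # concatenate along the channel axis:
    # 64 x 31 x 31 (main) + 2 x 31 x 31 (lesion) -> 66 x 31 x 31
    assert len(main_maps[0]) == len(lesion_maps[0])  # matching spatial height
    return main_maps + lesion_maps
```

The downstream conv layers then learn jointly from the generic features and the explicit hemorrhage evidence.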
64. Hybrid Architecture
[Diagram: the same hybrid architecture, annotated with the lesion detector's 2 x 56 x 56 per-tile output]