MasonCrinr committed
Commit: 8026e91
Parent(s): 1709f73
Upload 331 files
This view is limited to 50 files because it contains too many changes.
- .DS_Store +0 -0
- .gitattributes +1 -0
- .gitignore +133 -0
- LICENSE +22 -0
- MANIFEST.in +2 -0
- README.md +284 -0
- apex/.gitignore +5 -0
- apex/.nojekyll +0 -0
- apex/LICENSE +11 -0
- apex/README.md +99 -0
- apex/apex.patch +42 -0
- apex/apex/RNN/README.md +1 -0
- apex/apex/RNN/RNNBackend.py +365 -0
- apex/apex/RNN/__init__.py +3 -0
- apex/apex/RNN/cells.py +84 -0
- apex/apex/RNN/models.py +54 -0
- apex/apex/__init__.py +13 -0
- apex/apex/amp/README.md +72 -0
- apex/apex/amp/__init__.py +5 -0
- apex/apex/amp/__version__.py +2 -0
- apex/apex/amp/_amp_state.py +70 -0
- apex/apex/amp/_initialize.py +268 -0
- apex/apex/amp/_process_optimizer.py +411 -0
- apex/apex/amp/amp.py +177 -0
- apex/apex/amp/compat.py +42 -0
- apex/apex/amp/frontend.py +399 -0
- apex/apex/amp/handle.py +280 -0
- apex/apex/amp/lists/__init__.py +0 -0
- apex/apex/amp/lists/functional_overrides.py +77 -0
- apex/apex/amp/lists/tensor_overrides.py +63 -0
- apex/apex/amp/lists/torch_overrides.py +103 -0
- apex/apex/amp/opt.py +103 -0
- apex/apex/amp/rnn_compat.py +53 -0
- apex/apex/amp/scaler.py +210 -0
- apex/apex/amp/utils.py +213 -0
- apex/apex/amp/wrap.py +276 -0
- apex/apex/fp16_utils/README.md +16 -0
- apex/apex/fp16_utils/__init__.py +16 -0
- apex/apex/fp16_utils/fp16_optimizer.py +643 -0
- apex/apex/fp16_utils/fp16util.py +187 -0
- apex/apex/fp16_utils/loss_scaler.py +186 -0
- apex/apex/multi_tensor_apply/__init__.py +4 -0
- apex/apex/multi_tensor_apply/multi_tensor_apply.py +30 -0
- apex/apex/normalization/__init__.py +1 -0
- apex/apex/normalization/fused_layer_norm.py +160 -0
- apex/apex/optimizers/__init__.py +2 -0
- apex/apex/optimizers/fp16_optimizer.py +274 -0
- apex/apex/optimizers/fused_adam.py +147 -0
- apex/apex/parallel/LARC.py +97 -0
- apex/apex/parallel/README.md +66 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tensorboardX/screenshots/image.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,133 @@
# Global
.DS_Store
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
LICENSE
ADDED
@@ -0,0 +1,22 @@
Noncommercial Use License

Software Copyright (c) 2020 OpenAI

We don’t claim ownership of the content you create with Jukebox.
We only ask that you use Jukebox responsibly and clearly indicate your content was created using OpenAI’s Jukebox.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software, including without limitation the rights to use, copy,
modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

No portion of the Software, nor any content created with the Software, may be used for commercial purposes.

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

The above copyright notice and this permission notice need not be included with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
MANIFEST.in
ADDED
@@ -0,0 +1,2 @@
recursive-include jukebox *.py
recursive-include jukebox *.txt
README.md
ADDED
@@ -0,0 +1,284 @@
**Status:** Archive (code is provided as-is, no updates expected)

# Jukebox
Code for "Jukebox: A Generative Model for Music"

[Paper](https://arxiv.org/abs/2005.00341)
[Blog](https://openai.com/blog/jukebox)
[Explorer](http://jukebox.openai.com/)
[Colab](https://colab.research.google.com/github/openai/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb)

# Install
Install the conda package manager from https://docs.conda.io/en/latest/miniconda.html

```
# Required: Sampling
conda create --name jukebox python=3.7.5
conda activate jukebox
conda install mpi4py=3.0.3 # if this fails, try: pip install mpi4py==3.0.3
conda install pytorch=1.4 torchvision=0.5 cudatoolkit=10.0 -c pytorch
git clone https://github.com/openai/jukebox.git
cd jukebox
pip install -r requirements.txt
pip install -e .

# Required: Training
conda install av=7.0.01 -c conda-forge
pip install ./tensorboardX

# Optional: Apex for faster training with fused_adam
conda install pytorch=1.1 torchvision=0.3 cudatoolkit=10.0 -c pytorch
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
```

# Sampling
## Sampling from scratch
To sample normally, run the following command. The model can be `5b`, `5b_lyrics`, or `1b_lyrics`.
```
python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --sample_length_in_seconds=20 \
--total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
```
```
python jukebox/sample.py --model=1b_lyrics --name=sample_1b --levels=3 --sample_length_in_seconds=20 \
--total_sample_length_in_seconds=180 --sr=44100 --n_samples=16 --hop_fraction=0.5,0.5,0.125
```
The above generates the first `sample_length_in_seconds` seconds of audio from a song of total length `total_sample_length_in_seconds`.
To use multiple GPUs, launch the above scripts as `mpiexec -n {ngpus} python jukebox/sample.py ...` so they use `{ngpus}`.

The samples decoded from each level are stored in `{name}/level_{level}`.
You can also view the samples as an html with the aligned lyrics under `{name}/level_{level}/index.html`.
Run `python -m http.server` and open the html through the server to see the lyrics animate as the song plays.
A summary of all sampling data including zs, x, labels and sampling_kwargs is stored in `{name}/level_{level}/data.pth.tar`.
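If you want to inspect that summary programmatically, a short sketch (the key names follow the description above; exact contents and shapes depend on your run):

```python
import torch

# Load the summary saved by a sampling run and look at what's inside.
data = torch.load("sample_5b/level_0/data.pth.tar", map_location="cpu")
print(list(data.keys()))   # expected to include zs, x, labels and the sampling kwargs
```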

The hps are for a V100 GPU with 16 GB GPU memory. The `1b_lyrics`, `5b`, and `5b_lyrics` top-level priors take up
3.8 GB, 10.3 GB, and 11.5 GB, respectively. The peak memory usage to store the transformer key/value cache is about 400 MB
for `1b_lyrics` and 1 GB for `5b_lyrics` per sample. If you are having trouble with CUDA OOM issues, try `1b_lyrics` or
decrease `max_batch_size` in sample.py, and `--n_samples` in the script call.

On a V100, it takes about 3 hrs to fully sample 20 seconds of music. Since this is a long time, it is recommended to use `n_samples > 1` so you can generate as many samples as possible in parallel. The 1B lyrics model and the upsamplers can process 16 samples at a time, while 5B can fit only up to 3. Since the vast majority of time is spent on upsampling, we recommend using a multiple of 3 less than 16, like `--n_samples 15` for `5b_lyrics`. This will make the top level generate samples in groups of three while upsampling is done in one pass.
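As a quick illustration of the batching arithmetic above (a sketch, not part of the repo; the per-pass limit of 3 is the assumption stated in the paragraph):

```python
# Sketch: top-level passes needed for a given --n_samples when the 5B prior
# generates at most 3 samples per pass.
import math

def top_level_passes(n_samples, max_batch_size=3):
    return math.ceil(n_samples / max_batch_size)

print(top_level_passes(15))  # 5 full passes, then one upsampling pass for all 15
print(top_level_passes(16))  # 6 passes, the last one only a third full
```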

To continue sampling from already generated codes for a longer duration, you can run
```
python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --mode=continue \
--codes_file=sample_5b/level_0/data.pth.tar --sample_length_in_seconds=40 --total_sample_length_in_seconds=180 \
--sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
```
Here, we take the 20 second samples saved from the first sampling run at `sample_5b/level_0/data.pth.tar` and continue by adding 20 more seconds.

You could also continue directly from the level 2 saved outputs; just pass `--codes_file=sample_5b/level_2/data.pth.tar`.
Note this will upsample the full 40 second song at the end.

If you stopped sampling at only the first level and want to upsample the saved codes, you can run
```
python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --mode=upsample \
--codes_file=sample_5b/level_2/data.pth.tar --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 \
--sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
```
Here, we take the 20 second samples saved from the first sampling run at `sample_5b/level_2/data.pth.tar` and upsample the lower two levels.

## Prompt with your own music
If you want to prompt the model with your own creative piece or any other music, first save the pieces as wave files and run
```
python jukebox/sample.py --model=5b_lyrics --name=sample_5b_prompted --levels=3 --mode=primed \
--audio_file=path/to/recording.wav,awesome-mix.wav,fav-song.wav,etc.wav --prompt_length_in_seconds=12 \
--sample_length_in_seconds=20 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
```
This will load the four files, tile them to fill up to the `n_samples` batch size, and prime the model with the first `prompt_length_in_seconds` seconds.
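A rough sketch of that tiling behaviour (illustrative only, not the repo's loader): the prompt files are repeated round-robin until the batch is full.

```python
# Illustrative only: repeat the prompt files to fill an n_samples batch.
def tile_prompts(audio_files, n_samples):
    return [audio_files[i % len(audio_files)] for i in range(n_samples)]

print(tile_prompts(["recording.wav", "awesome-mix.wav", "fav-song.wav", "etc.wav"], 6))
# ['recording.wav', 'awesome-mix.wav', 'fav-song.wav', 'etc.wav', 'recording.wav', 'awesome-mix.wav']
```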

# Training
## VQVAE
To train a small vqvae, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae --name=small_vqvae --sample_length=262144 --bs=4 \
--audio_files_dir={audio_files_dir} --labels=False --train --aug_shift --aug_blend
```
Here, `{audio_files_dir}` is the directory in which you can put the audio files for your dataset, and `{ngpus}` is the number of GPUs you want to use to train.
The above trains a two-level VQ-VAE with `downs_t = (5,3)` and `strides_t = (2, 2)`, meaning we downsample the audio by `2**5 = 32` to get the first level of codes, and by `2**8 = 256` to get the second level of codes.
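A quick sanity check of that downsampling arithmetic (a sketch, not part of the training code):

```python
# Hop lengths implied by the small VQ-VAE hyperparameters above.
downs_t = (5, 3)      # number of downsampling blocks per level
strides_t = (2, 2)    # stride of each block

hops, total = [], 1
for down, stride in zip(downs_t, strides_t):
    total *= stride ** down
    hops.append(total)

print(hops)  # [32, 256]: level-0 codes every 32 samples, level-1 codes every 256 samples
```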
Checkpoints are stored in the `logs` folder. You can monitor the training by running Tensorboard
```
tensorboard --logdir logs
```

## Prior
### Train prior or upsamplers
Once the VQ-VAE is trained, we can restore it from its saved checkpoint and train priors on the learnt codes.
To train the top-level prior, we can run

```
mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae,small_prior,all_fp16,cpu_ema --name=small_prior \
--sample_length=2097152 --bs=4 --audio_files_dir={audio_files_dir} --labels=False --train --test --aug_shift --aug_blend \
--restore_vqvae=logs/small_vqvae/checkpoint_latest.pth.tar --prior --levels=2 --level=1 --weight_decay=0.01 --save_iters=1000
```

To train the upsampler, we can run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae,small_upsampler,all_fp16,cpu_ema --name=small_upsampler \
--sample_length=262144 --bs=4 --audio_files_dir={audio_files_dir} --labels=False --train --test --aug_shift --aug_blend \
--restore_vqvae=logs/small_vqvae/checkpoint_latest.pth.tar --prior --levels=2 --level=0 --weight_decay=0.01 --save_iters=1000
```
We pass `sample_length = n_ctx * downsample_of_level` so that after downsampling the tokens match the `n_ctx` of the prior hps.
Here, `n_ctx = 8192` and `downsamples = (32, 256)`, giving `sample_lengths = (8192 * 32, 8192 * 256) = (262144, 2097152)` respectively for the bottom and top level.
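The same arithmetic as a one-liner (a sketch; the numbers follow from the hyperparameters above):

```python
# sample_length must equal n_ctx times the hop length of the level being trained.
n_ctx = 8192
hops = (32, 256)                  # level-0 and level-1 hop lengths of the small VQ-VAE
print([n_ctx * h for h in hops])  # [262144, 2097152] -> the --sample_length values used above
```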

### Learning rate annealing
To get the best sample quality, anneal the learning rate to 0 near the end of training. To do so, continue training from the latest
checkpoint and run with
```
--restore_prior="path/to/checkpoint" --lr_use_linear_decay --lr_start_linear_decay={already_trained_steps} --lr_decay={decay_steps_as_needed}
```

### Reuse pre-trained VQ-VAE and train top-level prior on new dataset from scratch.
#### Train without labels
Our pre-trained VQ-VAE can produce compressed codes for a wide variety of genres of music, and the pre-trained upsamplers
can upsample them back to audio that sounds very similar to the original audio.
To re-use these for a new dataset of your choice, you can retrain just the top level.

To train the top level on a new dataset, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior \
--sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=False --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
```
Training the `small_prior` with a batch size of 2, 4, and 8 requires 6.7 GB, 9.3 GB, and 15.8 GB of GPU memory, respectively. A few days to a week of training typically yields reasonable samples when the dataset is homogeneous (e.g. all piano pieces, songs of the same style, etc).

Near the end of training, follow [this](#learning-rate-annealing) to anneal the learning rate to 0.

#### Sample from new model
You can then run sample.py with the top level of our models replaced by your new model. To do so,
- Add an entry `my_model=("vqvae", "upsampler_level_0", "upsampler_level_1", "small_prior")` in `MODELS` in `make_models.py`.
- Update the `small_prior` dictionary in `hparams.py` to include `restore_prior='path/to/checkpoint'`. If you
changed any hps directly in the command line script (e.g. `heads`), make sure to update them in the dictionary too so
that `make_models` restores our checkpoint correctly.
- Run sample.py as outlined in the sampling section, but now with `--model=my_model`.

For example, let's say we trained `small_vqvae`, `small_prior`, and `small_upsampler` under `/path/to/jukebox/logs`. In `make_models.py`, we are going to declare a tuple of the new models as `my_model`.
```
MODELS = {
    '5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
    '5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
    '1b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_1b_lyrics"),
    'my_model': ("my_small_vqvae", "my_small_upsampler", "my_small_prior"),
}
```

Next, in `hparams.py`, we add them to the registry with the corresponding `restore_` paths and any other command line options used during training. Another important note is that for top-level priors with lyric conditioning, we have to locate a self-attention layer that shows alignment between the lyric and music tokens. Look for layers where `prior.prior.transformer._attn_mods[layer].attn_func` is either 6 or 7. If your model is starting to sing along to the lyrics, it means some layer, head pair has learned alignment. Congrats! A short helper sketch for locating such layers follows the example below.
```
my_small_vqvae = Hyperparams(
    restore_vqvae='/path/to/jukebox/logs/small_vqvae/checkpoint_some_step.pth.tar',
)
my_small_vqvae.update(small_vqvae)
HPARAMS_REGISTRY["my_small_vqvae"] = my_small_vqvae

my_small_prior = Hyperparams(
    restore_prior='/path/to/jukebox/logs/small_prior/checkpoint_latest.pth.tar',
    level=1,
    labels=False,
    # TODO For the two lines below, if `--labels` was used and the model is
    # trained with lyrics, find and enter the layer, head pair that has learned
    # alignment.
    alignment_layer=47,
    alignment_head=0,
)
my_small_prior.update(small_prior)
HPARAMS_REGISTRY["my_small_prior"] = my_small_prior

my_small_upsampler = Hyperparams(
    restore_prior='/path/to/jukebox/logs/small_upsampler/checkpoint_latest.pth.tar',
    level=0,
    labels=False,
)
my_small_upsampler.update(small_upsampler)
HPARAMS_REGISTRY["my_small_upsampler"] = my_small_upsampler
```
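If you'd rather not inspect layers by hand, a hypothetical helper along these lines can list candidates; it is not part of the repo and relies only on the `prior.prior.transformer._attn_mods[layer].attn_func` attribute path mentioned above:

```python
# Hypothetical helper: list layers whose attention attends over lyric tokens
# (attn_func 6 or 7, as described above) for a loaded top-level prior.
def lyric_attention_layers(prior):
    return [i for i, attn in enumerate(prior.prior.transformer._attn_mods)
            if attn.attn_func in (6, 7)]

# print(lyric_attention_layers(top_prior))  # candidates for alignment_layer
```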

#### Train with labels
To train with your own metadata for your audio files, implement `get_metadata` in `data/files_dataset.py` to return the
`artist`, `genre` and `lyrics` for a given audio file. For now, you can pass `''` for lyrics to not use any lyrics.
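A minimal sketch of what such an implementation might look like (the lookup table and its contents are hypothetical; check the actual method signature in `jukebox/data/files_dataset.py` before copying):

```python
# Hypothetical sketch of get_metadata; MY_METADATA is an assumed per-file lookup
# you would build for your own dataset.
MY_METADATA = {
    "track001.wav": ("my_artist", "my_genre", ""),  # '' means: no lyrics for this file
}

def get_metadata(self, filename, test):
    artist, genre, lyrics = MY_METADATA.get(filename, ("unknown", "unknown", ""))
    return artist, genre, lyrics
```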

For training with labels, we'll use `small_labelled_prior` in `hparams.py`, and we set `labels=True,labels_v3=True`.
We use two kinds of label information:
- Artist/Genre:
  - For each file, we return an artist_id and a list of genre_ids. The reason we have a list and not a single genre_id
    is that in v2 we split genres like `blues_rock` into a bag of words `[blues, rock]` and pass at most
    `max_bow_genre_size` of those; in v3 we consider it as a single word and just set `max_bow_genre_size=1`.
  - Update the `v3_artist_ids` and `v3_genre_ids` to use ids from your new dataset.
  - In `small_labelled_prior`, set the hps `y_bins = (number_of_genres, number_of_artists)` and `max_bow_genre_size=1`.
- Timing:
  - For each chunk of audio, we return the `total_length` of the song, the `offset` the current audio chunk is at, and
    the `sample_length` of the audio chunk. We have three timing embeddings: total length, our current position, and our
    current position as a fraction of the total length, and we divide the range of these values into `t_bins` discrete bins.
  - In `small_labelled_prior`, set the hps `min_duration` and `max_duration` to be the shortest/longest duration of audio
    files you want for your dataset, and `t_bins` for how many bins you want to discretize timing information into. Note
    `min_duration * sr` needs to be at least `sample_length` to have an audio chunk in it.

After these modifications, to train a top-level prior with labels, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_labelled_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior_labels \
--sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
```

For sampling, follow the same instructions as [above](#sample-from-new-model) but use `small_labelled_prior` instead of `small_prior`.

#### Train with lyrics
To additionally train with lyrics, update `get_metadata` in `data/files_dataset.py` to return `lyrics` too.
For training with lyrics, we'll use `small_single_enc_dec_prior` in `hparams.py`.
- Lyrics:
  - For each file, we linearly align the lyric characters to the audio, find the position in the lyrics that corresponds to
    the midpoint of our audio chunk, and pass a window of `n_tokens` lyric characters centred around that (a minimal sketch of this windowing follows this list).
  - In `small_single_enc_dec_prior`, set the hps `use_tokens=True` and `n_tokens` to be the number of lyric characters
    to use for an audio chunk. Set it according to the `sample_length` you're training on, so that it's large enough that
    the lyrics for an audio chunk are almost always found inside a window of that size.
  - If you use a non-English vocabulary, update `text_processor.py` with your new vocab and set
    `n_vocab = number of characters in vocabulary` accordingly in `small_single_enc_dec_prior`. In v2 we had `n_vocab=80`,
    and in v3 we missed `+`, so `n_vocab=79` characters.
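A minimal sketch of the centred window described above (names are illustrative; the repo's own implementation lives in its data pipeline):

```python
# Illustrative only: a window of n_tokens lyric characters centred on the lyric
# position that corresponds to the midpoint of the current audio chunk.
def centred_lyric_window(lyrics, offset, sample_length, total_length, n_tokens):
    midpoint_fraction = (offset + sample_length / 2) / total_length
    centre = int(midpoint_fraction * len(lyrics))
    start = max(0, min(centre - n_tokens // 2, len(lyrics) - n_tokens))
    return lyrics[start:start + n_tokens]
```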

After these modifications, to train a top-level prior with labels and lyrics, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_single_enc_dec_prior_labels \
--sample_length=786432 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
```
To simplify hps choices, here we used a `single_enc_dec` model like the `1b_lyrics` model, which combines both the encoder and
decoder of the transformer into a single model. We do so by merging the lyric vocab and vq-vae vocab into a single
larger vocab, and flattening the lyric tokens and the vq-vae codes into a single sequence of length `n_ctx + n_tokens`.
This uses `attn_order=12`, which includes `prime_attention` layers with keys/values from lyrics and queries from audio.
If you instead want to use a model with the usual encoder-decoder style transformer, use `small_sep_enc_dec_prior`.
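Schematically, the merged sequence looks like this (a sketch under the assumption that the lyric ids are offset past the VQ-VAE codebook; which of the two vocabularies gets shifted is an implementation detail of the repo):

```python
import torch

# Sketch: one flattened sequence of length n_tokens + n_ctx for the single_enc_dec prior.
def merge_lyrics_and_codes(lyric_tokens, vq_codes, vq_vocab_size):
    shifted_lyrics = lyric_tokens + vq_vocab_size  # keep the two vocabularies disjoint
    return torch.cat([shifted_lyrics, vq_codes], dim=-1)
```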

For sampling, follow the same instructions as [above](#sample-from-new-model) but use `small_single_enc_dec_prior` instead of
`small_prior`. To also get the alignment between lyrics and samples in the saved html, you'll need to set `alignment_layer`
and `alignment_head` in `small_single_enc_dec_prior`. To find which layer/head is best to use, run a forward pass on a training example,
save the attention weight tensors for all prime_attention layers, and pick the (layer, head) which has the best linear alignment
pattern between the lyrics keys and music queries.

### Fine-tune pre-trained top-level prior to new style(s)
Previously, we showed how to train a small top-level prior from scratch. Assuming you have a GPU with at least 15 GB of memory and support for fp16, you could fine-tune from our pre-trained 1B top-level prior. Here are the steps:

- Support `--labels=True` by implementing `get_metadata` in `jukebox/data/files_dataset.py` for your dataset.
- Add new entries in `jukebox/data/ids`. We recommend replacing existing mappings (e.g. rename `"unknown"`, etc with styles of your choice). This uses the pre-trained style vectors as initialization and could potentially save some compute.

After these modifications, run
```
mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,prior_1b_lyrics,all_fp16,cpu_ema --name=finetuned \
--sample_length=1048576 --bs=1 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
--labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
```
To get the best sample quality, it is recommended to anneal the learning rate in the end. Training the 5B top-level prior requires GPipe, which is not supported in this release.

# Citation

Please cite using the following bibtex entry:

```
@article{dhariwal2020jukebox,
  title={Jukebox: A Generative Model for Music},
  author={Dhariwal, Prafulla and Jun, Heewoo and Payne, Christine and Kim, Jong Wook and Radford, Alec and Sutskever, Ilya},
  journal={arXiv preprint arXiv:2005.00341},
  year={2020}
}
```

# License
[Noncommercial Use License](./LICENSE)

It covers both released code and weights.
apex/.gitignore
ADDED
@@ -0,0 +1,5 @@
apex.egg-info
dist
build
docs/build
*~
apex/.nojekyll
ADDED
File without changes
apex/LICENSE
ADDED
@@ -0,0 +1,11 @@
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
apex/README.md
ADDED
@@ -0,0 +1,99 @@
# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)

# Contents

## 1. Amp: Automatic Mixed Precision

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
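For orientation, those lines usually look roughly like this (a minimal sketch using the documented `amp.initialize`/`amp.scale_loss` API; the toy model and the `O1` opt level are just placeholders):

```python
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# 1) let Amp patch the model/optimizer for mixed precision
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
# 2)-3) scale the loss for the backward pass
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```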

[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).

[API Documentation](https://nvidia.github.io/apex/amp.html)

[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)

[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.

[API Documentation](https://nvidia.github.io/apex/parallel.html)

[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)

[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)

The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
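As a rough sketch of how these pieces are typically combined (the conversion call `apex.parallel.convert_syncbn_model` is the documented entry point; distributed process-group setup is omitted here):

```python
import torch
import apex

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).cuda()

# Swap BatchNorm layers for SyncBatchNorm, then wrap for multiprocess training.
model = apex.parallel.convert_syncbn_model(model)
model = apex.parallel.DistributedDataParallel(model)  # requires torch.distributed to be initialized
```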

# Requirements

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

We recommend the latest stable release, obtainable from
[https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).

It's often convenient to use Apex in Docker containers. Compatible options include:
* [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled. To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
* [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.

See the [Docker example folder](https://github.com/NVIDIA/apex/tree/master/examples/docker) for details.

# Quick Start

### Linux

For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
```

Apex also supports a Python-only build (required with Pytorch 0.4) via
```
$ pip install -v --no-cache-dir .
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.

### Windows support
Windows support is experimental, and Linux is recommended. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. `pip install -v --no-cache-dir .` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
apex/apex.patch
ADDED
@@ -0,0 +1,42 @@
diff --git a/csrc/fused_adam_cuda_kernel.cu b/csrc/fused_adam_cuda_kernel.cu
index 34f7aa2..95581d1 100644
--- a/csrc/fused_adam_cuda_kernel.cu
+++ b/csrc/fused_adam_cuda_kernel.cu
@@ -19,8 +19,8 @@ typedef enum{
 
 template <typename T, typename GRAD_T>
 __global__ void adam_cuda_kernel(
-        T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+        GRAD_T* __restrict__ p,
+        T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
         T* __restrict__ m,
         T* __restrict__ v,
         const GRAD_T * __restrict__ g,
@@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel(
         else // Mode 1
             denom = sqrtf(v[j]) + eps;
         float update = (m[j]/denom) + (decay*p[j]);
-        p[j] = p[j] - (step_size*update);
+        p[j] = (GRAD_T) (p[j] - (step_size*update));
         if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
   }
 }
@@ -93,14 +93,14 @@ void fused_adam_cuda(
 
   if (g.scalar_type() == at::ScalarType::Half) {
 //all other values should be fp32 for half gradients
-    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
+//    AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
 //dispatch is done on the gradient type
     using namespace at; // prevents "toString is undefined" errors
     DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
       using accscalar_t = at::acc_type<scalar_t_0, true>;
       adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-        p.data<accscalar_t>(),
-        p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
+        p.data<scalar_t_0>(),
+        NULL, //don't output p_copy for fp32, it's wasted write
         m.data<accscalar_t>(),
         v.data<accscalar_t>(),
         g.data<scalar_t_0>(),
apex/apex/RNN/README.md
ADDED
@@ -0,0 +1 @@
Under construction...
apex/apex/RNN/RNNBackend.py
ADDED
@@ -0,0 +1,365 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

import torch.nn.functional as F

import math


def is_iterable(maybe_iterable):
    return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)


def flatten_list(tens_list):
    """
    flatten_list
    """
    if not is_iterable(tens_list):
        return tens_list

    return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size())


#These modules always assumes batch_first
class bidirectionalRNN(nn.Module):
    """
    bidirectionalRNN
    """
    def __init__(self, inputRNN, num_layers=1, dropout=0):
        super(bidirectionalRNN, self).__init__()
        self.dropout = dropout
        self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout=dropout)
        self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout=dropout)
        self.rnns = nn.ModuleList([self.fwd, self.bckwrd])

    #collect hidden option will return all hidden/cell states from entire RNN
    def forward(self, input, collect_hidden=False):
        """
        forward()
        """
        seq_len = input.size(0)
        bsz = input.size(1)

        fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden=collect_hidden))
        bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden=collect_hidden))

        output = torch.cat([fwd_out, bckwrd_out], -1)
        hiddens = tuple(torch.cat(hidden, -1) for hidden in zip(fwd_hiddens, bckwrd_hiddens))

        return output, hiddens

    def reset_parameters(self):
        """
        reset_parameters()
        """
        for rnn in self.rnns:
            rnn.reset_parameters()

    def init_hidden(self, bsz):
        """
        init_hidden()
        """
        for rnn in self.rnns:
            rnn.init_hidden(bsz)

    def detach_hidden(self):
        """
        detach_hidden()
        """
        for rnn in self.rnns:
            rnn.detachHidden()

    def reset_hidden(self, bsz):
        """
        reset_hidden()
        """
        for rnn in self.rnns:
            rnn.reset_hidden(bsz)

    def init_inference(self, bsz):
        """
        init_inference()
        """
        for rnn in self.rnns:
            rnn.init_inference(bsz)


#assumes hidden_state[0] of inputRNN is output hidden state
#constructor either takes an RNNCell or list of RNN layers
class stackedRNN(nn.Module):
    """
    stackedRNN
    """
    def __init__(self, inputRNN, num_layers=1, dropout=0):
        super(stackedRNN, self).__init__()

        self.dropout = dropout

        if isinstance(inputRNN, RNNCell):
            self.rnns = [inputRNN]
            for i in range(num_layers - 1):
                self.rnns.append(inputRNN.new_like(inputRNN.output_size))
        elif isinstance(inputRNN, list):
            assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
            self.rnns = inputRNN
        else:
            raise RuntimeError()

        self.nLayers = len(self.rnns)

        self.rnns = nn.ModuleList(self.rnns)

    '''
    Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
    If collect hidden will also return Tuple(
        [n_hidden_states][sequence steps] Tensor([layer][batch size][features])
    )
    If not collect hidden will also return Tuple(
        [n_hidden_states] Tensor([layer][batch size][features])
    )
    '''
    def forward(self, input, collect_hidden=False, reverse=False):
        """
        forward()
        """
        seq_len = input.size(0)
        bsz = input.size(1)
        inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)

        hidden_states = [[] for i in range(self.nLayers)]
        outputs = []

        for seq in inp_iter:
            for layer in range(self.nLayers):

                if layer == 0:
                    prev_out = input[seq]

                outs = self.rnns[layer](prev_out)

                if collect_hidden:
                    hidden_states[layer].append(outs)
                elif seq == seq_len - 1:
                    hidden_states[layer].append(outs)

                prev_out = outs[0]

            outputs.append(prev_out)

        if reverse:
            outputs = list(reversed(outputs))
        '''
        At this point outputs is in format:
        list( [seq_length] x Tensor([bsz][features]) )
        need to convert it to:
        list( Tensor([seq_length][bsz][features]) )
        '''
        output = flatten_list(outputs)

        '''
        hidden_states at this point is in format:
        list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
        need to convert it to:
          For not collect hidden:
            list( [hidden_states] x Tensor([layer][bsz][features]) )
          For collect hidden:
            list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
        '''
        if not collect_hidden:
            seq_len = 1
        n_hid = self.rnns[0].n_hidden_states
        new_hidden = [[[None for k in range(self.nLayers)] for j in range(seq_len)] for i in range(n_hid)]

        for i in range(n_hid):
            for j in range(seq_len):
                for k in range(self.nLayers):
                    new_hidden[i][j][k] = hidden_states[k][j][i]

        hidden_states = new_hidden
        #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
        #Reverse seq_length if reverse
        if reverse:
            hidden_states = list(list(reversed(list(entry))) for entry in hidden_states)

        #flatten layer dimension into tensor
        hiddens = list(list(
            flatten_list(seq) for seq in hidden)
            for hidden in hidden_states)

        #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
        #Remove seq_length dimension if not collect_hidden
        if not collect_hidden:
            hidden_states = list(entry[0] for entry in hidden_states)
        return output, hidden_states

    def reset_parameters(self):
        """
        reset_parameters()
        """
        for rnn in self.rnns:
            rnn.reset_parameters()

    def init_hidden(self, bsz):
        """
        init_hidden()
        """
        for rnn in self.rnns:
            rnn.init_hidden(bsz)

    def detach_hidden(self):
        """
        detach_hidden()
        """
        for rnn in self.rnns:
            rnn.detach_hidden()

    def reset_hidden(self, bsz):
        """
        reset_hidden()
        """
        for rnn in self.rnns:
            rnn.reset_hidden(bsz)

    def init_inference(self, bsz):
        """
        init_inference()
        """
        for rnn in self.rnns:
            rnn.init_inference(bsz)


class RNNCell(nn.Module):
    """
    RNNCell
    gate_multiplier is related to the architecture you're working with
    For LSTM-like it will be 4 and GRU-like will be 3.
    Always assumes input is NOT batch_first.
    Output size that's not hidden size will use output projection
    Hidden_states is number of hidden states that are needed for cell
    if one will go directly to cell as tensor, if more will go as list
    """
    def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states=2, bias=False, output_size=None):
        super(RNNCell, self).__init__()

        self.gate_multiplier = gate_multiplier
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell = cell
        self.bias = bias
        self.output_size = output_size
        if output_size is None:
            self.output_size = hidden_size

        self.gate_size = gate_multiplier * self.hidden_size
        self.n_hidden_states = n_hidden_states

        self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
        self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))

        #Check if there's recurrent projection
        if(self.output_size != self.hidden_size):
            self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))

        self.b_ih = self.b_hh = None
        if self.bias:
            self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
            self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))

        #hidden states for forward
        self.hidden = [None for states in range(self.n_hidden_states)]

        self.reset_parameters()

    def new_like(self, new_input_size=None):
        """
        new_like()
        """
        if new_input_size is None:
            new_input_size = self.input_size

        return type(self)(self.gate_multiplier,
                          new_input_size,
                          self.hidden_size,
                          self.cell,
                          self.n_hidden_states,
                          self.bias,
                          self.output_size)

    #Use xavier where we can (weights), otherwise use uniform (bias)
    def reset_parameters(self, gain=1):
        """
        reset_parameters()
        """
        stdev = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-stdev, stdev)
    '''
    Xavier reset:
    def reset_parameters(self, gain=1):
        stdv = 1.0 / math.sqrt(self.gate_size)

        for param in self.parameters():
            if (param.dim() > 1):
                torch.nn.init.xavier_normal(param, gain)
            else:
                param.data.uniform_(-stdv, stdv)
    '''
    def init_hidden(self, bsz):
        """
        init_hidden()
        """
        for param in self.parameters():
            if param is not None:
                a_param = param
                break

        for i, _ in enumerate(self.hidden):
            if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):

                if i == 0:
                    hidden_size = self.output_size
                else:
                    hidden_size = self.hidden_size

                tens = a_param.data.new(bsz, hidden_size).zero_()
                self.hidden[i] = Variable(tens, requires_grad=False)

    def reset_hidden(self, bsz):
        """
        reset_hidden()
        """
        for i, _ in enumerate(self.hidden):
            self.hidden[i] = None
        self.init_hidden(bsz)

    def detach_hidden(self):
        """
        detach_hidden()
        """
        for i, _ in enumerate(self.hidden):
            if self.hidden[i] is None:
                raise RuntimeError("Must initialize hidden state before you can detach it")
        for i, _ in enumerate(self.hidden):
            self.hidden[i] = self.hidden[i].detach()

    def forward(self, input):
        """
        forward()
        if not inited or bsz has changed this will create hidden states
        """
        self.init_hidden(input.size()[0])

        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
        self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
        if(self.n_hidden_states > 1):
            self.hidden = list(self.hidden)
        else:
            self.hidden = [self.hidden]

        if self.output_size != self.hidden_size:
            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)

        return tuple(self.hidden)
apex/apex/RNN/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .models import LSTM, GRU, ReLU, Tanh, mLSTM

__all__ = ['models']
apex/apex/RNN/cells.py
ADDED
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .RNNBackend import RNNCell

from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend

import math


class mLSTMRNNCell(RNNCell):
    """
    mLSTMRNNCell
    """

    def __init__(self, input_size, hidden_size, bias=False, output_size=None):
        gate_multiplier = 4
        super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states=2, bias=bias, output_size=output_size)

        self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))
        self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))

        self.reset_parameters()

    def forward(self, input):
        """
        mLSTMRNNCell.forward()
        """
        #if not inited or bsz has changed this will create hidden states
        self.init_hidden(input.size()[0])

        hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden

        self.hidden = list(
            self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
                      b_ih=self.b_ih, b_hh=self.b_hh)
        )

        if self.output_size != self.hidden_size:
            self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
        return tuple(self.hidden)

    def new_like(self, new_input_size=None):
        if new_input_size is None:
            new_input_size = self.input_size

        return type(self)(
            new_input_size,
            self.hidden_size,
            self.bias,
            self.output_size)


def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
    """
    mLSTMCell
    """

    if input.is_cuda:
        igates = F.linear(input, w_ih)
        m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
        hgates = F.linear(m, w_hh)

        state = fusedBackend.LSTMFused.apply
        return state(igates, hgates, hidden[1], b_ih, b_hh)

    hx, cx = hidden

    m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
    gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy
apex/apex/RNN/models.py
ADDED
@@ -0,0 +1,54 @@
+import torch
+
+from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
+
+from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
+from .cells import mLSTMRNNCell, mLSTMCell
+
+def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
+    """
+    :class:`toRNNBackend`
+    """
+
+    if bidirectional:
+        return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
+    else:
+        return stackedRNN(inputRNN, num_layers, dropout = dropout)
+
+
+def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
+    """
+    :class:`LSTM`
+    """
+    inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
+    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
+
+def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
+    """
+    :class:`GRU`
+    """
+    inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
+    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
+
+def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
+    """
+    :class:`ReLU`
+    """
+    inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
+    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
+
+def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
+    """
+    :class:`Tanh`
+    """
+    inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
+    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
+
+def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
+    """
+    :class:`mLSTM`
+    """
+    inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
+    return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
+
+
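A hedged usage sketch of these factories follows. The time-major layout (seq_len, batch, input_size) and the `.cuda()` requirement are assumptions about the `stackedRNN` backend rather than something this file states, and whether the call returns only the outputs or an (output, hidden) pair depends on that backend, which is not shown here.

```python
import torch
from apex.RNN import mLSTM

# Two stacked mLSTM layers; assumes an apex build with CUDA available.
rnn = mLSTM(input_size=64, hidden_size=128, num_layers=2).cuda()

# Assumed layout: (seq_len, batch, input_size); batch_first is accepted by the
# factory signature but is not used by this backend.
inp = torch.randn(20, 8, 64, device="cuda")
out = rnn(inp)
```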
apex/apex/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from . import parallel
+from . import amp
+from . import fp16_utils
+
+# For optimizers and normalization there is no Python fallback.
+# Absence of the cuda backend is a hard error.
+# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
+# to be triggered lazily: if someone has installed with --cpp_ext and --cuda_ext,
+# they expect those backends to be available, and if for some reason they actually
+# aren't (for example because they built improperly in a way that isn't revealed
+# until load time), the error message is timely and visible.
+from . import optimizers
+from . import normalization
apex/apex/amp/README.md
ADDED
@@ -0,0 +1,72 @@
+# amp: Automatic Mixed Precision
+
+## Annotating User Functions
+
+Nearly all PyTorch user code needs nothing more than the two steps
+above to use amp. After all, custom layers are built out of simpler
+PyTorch components, and amp already can see those.
+
+However, any custom C++ or CUDA code is outside of amp's (default)
+view of things. For example, suppose I implemented a new recurrent
+cell called a "forgetful recurrent unit" that calls directly into a
+CUDA backend:
+
+```python
+from backend import FRUBackend
+
+def fru(input, hidden, weight, bias):
+    # call to CUDA code
+    FRUBackend(input, hidden, weight, bias)
+```
+
+In this case, it is possible to get a runtime type mismatch. For
+example, you might have `input` in fp16, and `weight` in fp32, and amp
+doesn't have the visibility to insert an appropriate cast.
+
+amp exposes two ways to handle "invisible" backend code: function
+annotations and explicit registration.
+
+#### Function annotation
+
+The first way to handle backend code is a set of function annotations:
+
+- `@amp.half_function`
+- `@amp.float_function`
+- `@amp.promote_function`
+
+These correspond to:
+
+- Cast all arguments to fp16
+- Cast all arguments to fp32
+- If there are any type mismatches, cast everything to the widest type
+
+In our example, we believe that the FRU unit is fp16-safe and will get
+performance gains from casting its arguments to fp16, so we write:
+
+```python
+@amp.half_function
+def fru(input, hidden, weight, bias):
+    #...
+```
+
+#### Explicit registration
+
+The other way to handle backend code is with explicit function
+registration:
+
+- `amp.register_half_function(module, function_name)`
+- `amp.register_float_function(module, function_name)`
+- `amp.register_promote_function(module, function_name)`
+
+When using this API, `module` is the containing class or module for
+the function, and `function_name` is the _string_ name of the
+function. Note that the function must be registered before the call to
+`amp.initialize()`.
+
+For our FRU unit, we can register the backend function directly:
+
+```python
+import backend
+
+amp.register_half_function(backend, 'FRUBackend')
+```
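As a hedged end-to-end sketch of the registration flow (the `backend` module and `FRUBackend` name are the hypothetical ones from the README example above, and the model/optimizer are assumed to be defined elsewhere): registration has to come first, because the registry is consumed at initialization time, when the wrappers are actually installed.

```python
from apex import amp
import backend  # hypothetical CUDA extension from the FRU example above

# 1) Register first...
amp.register_half_function(backend, 'FRUBackend')

# 2) ...then initialize; the registry is drained during initialization,
#    so later registrations would not affect this run.
# model, optimizer defined elsewhere
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```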
apex/apex/amp/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .amp import init, half_function, float_function, promote_function,\
+    register_half_function, register_float_function, register_promote_function
+from .handle import scale_loss, disable_casts
+from .frontend import initialize
+from ._amp_state import master_params, _amp_state
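These exports make up the public surface used in a typical training loop. A minimal sketch of how `initialize` and `scale_loss` fit together (the toy model, data, and hyperparameters here are placeholders, not part of this package):

```python
import torch
from apex import amp

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The returned objects replace the originals for the rest of the script.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(10):
    inp = torch.randn(32, 128, device="cuda")
    loss = model(inp).sum()
    optimizer.zero_grad()
    # Scales the loss so FP16 gradients don't underflow; gradients are
    # unscaled before the optimizer step.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```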
apex/apex/amp/__version__.py
ADDED
@@ -0,0 +1,2 @@
+VERSION = (0, 1, 0)
+__version__ = '.'.join(map(str, VERSION))
apex/apex/amp/_amp_state.py
ADDED
@@ -0,0 +1,70 @@
+# This is a "header object" that allows different amp modules to communicate.
+# I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
+# But apparently it's ok:
+# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
+import os
+import torch
+
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 0:
+    import collections.abc as container_abcs
+else:
+    from torch._six import container_abcs
+
+
+class AmpState(object):
+    def __init__(self):
+        self.hard_override=False
+        self.allow_incoming_model_not_fp32 = False
+        self.verbosity=1
+
+
+# Attribute stash. Could also just stash things as global module attributes.
+_amp_state = AmpState()
+
+
+def warn_or_err(msg):
+    if _amp_state.hard_override:
+        print("Warning: " + msg)
+    else:
+        raise RuntimeError(msg)
+        # I'm not sure if allowing hard_override is a good idea.
+        # + " If you're sure you know what you're doing, supply " +
+        # "hard_override=True to amp.initialize.")
+
+
+distributed = False
+if 'WORLD_SIZE' in os.environ:
+    distributed = int(os.environ['WORLD_SIZE']) > 1
+
+
+def maybe_print(msg, rank0=False):
+    if _amp_state.verbosity > 0:
+        if rank0:
+            if distributed:
+                if torch.distributed.get_rank() == 0:
+                    print(msg)
+            else:
+                print(msg)
+        else:
+            print(msg)
+
+
+# def iter_params(param_groups):
+#     for group in param_groups:
+#         for p in group['params']:
+#             yield p
+
+
+def master_params(optimizer):
+    """
+    Generator expression that iterates over the params owned by ``optimizer``.
+
+    Args:
+        optimizer: An optimizer previously returned from ``amp.initialize``.
+    """
+    for group in optimizer.param_groups:
+        for p in group['params']:
+            yield p
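`master_params` exists mainly so that utilities which need the FP32 master copies (gradient clipping being the usual one) can reach them without knowing how Amp stores them. A short sketch of that use, assuming `optimizer` is an optimizer already returned from `amp.initialize`:

```python
import torch
from apex import amp

# Clip on the master (FP32) gradients rather than the FP16 model gradients;
# `optimizer` is assumed to come from amp.initialize(...).
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
```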
apex/apex/amp/_initialize.py
ADDED
@@ -0,0 +1,268 @@
+import torch
+from torch._six import string_classes
+import functools
+import numpy as np
+import warnings
+from ._amp_state import _amp_state, warn_or_err, container_abcs
+from .handle import disable_casts
+from .scaler import LossScaler
+from ._process_optimizer import _process_optimizer
+from apex.fp16_utils import convert_network
+from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
+from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
+from ..optimizers import FusedAdam
+from ..parallel import DistributedDataParallel as apex_DDP
+from ..parallel.LARC import LARC
+
+
+def to_type(dtype, t):
+    if isinstance(t, torch.Tensor):
+        if not t.is_cuda:
+            # This should not be a hard error, since it may be legitimate.
+            warnings.warn("An input tensor was not cuda.")
+        # GANs require this.
+        # if t.requires_grad:
+        #     warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
+        #         "its gradients will not be properly allreduced by DDP.")
+        if t.is_floating_point():
+            return t.to(dtype)
+        return t
+    else:
+        # Trust the user's custom batch type, that's all I can do here.
+        return t.to(dtype)
+
+
+# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
+def applier(value, fn):
+    if isinstance(value, torch.Tensor):
+        return fn(value)
+    elif isinstance(value, string_classes):
+        return value
+    elif isinstance(value, np.ndarray):
+        return value
+    elif hasattr(value, "to"): # Allow handling of custom batch classes
+        return fn(value)
+    elif isinstance(value, container_abcs.Mapping):
+        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
+    elif isinstance(value, container_abcs.Iterable):
+        return type(value)(applier(v, fn) for v in value)
+    else:
+        # Do I want this to fire off even if someone chooses to pass something ordinary like
+        # an int or float? May be more annoying than it's worth.
+        # print("Warning: unrecognized type in applier. If your input data is a custom class, "
+        #     "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
+        #     "Amp will check for your custom to() and invoke it to cast the batch's "
+        #     "floating-point Tensors to the appropriate type. "
+        #     "Also, if your data is a custom class, it is your responsibility to ensure that "
+        #     "any Tensors you want to be cuda are already cuda."
+        return value
+
+
+def check_models(models):
+    for model in models:
+        parallel_type = None
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            parallel_type = "torch.nn.parallel.DistributedDataParallel"
+        if isinstance(model, apex_DDP):
+            parallel_type = "apex.parallel.DistributedDataParallel"
+        if isinstance(model, torch.nn.parallel.DataParallel):
+            parallel_type = "torch.nn.parallel.DataParallel"
+        if parallel_type is not None:
+            raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
+                "Parallel wrappers should only be applied to the model(s) AFTER \n"
+                "the model(s) have been returned from amp.initialize.")
+
+
+def check_params_fp32(models):
+    for model in models:
+        for name, param in model.named_parameters():
+            if param.is_floating_point():
+                if 'Half' in param.type():
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, param.type()))
+                elif not param.is_cuda:
+                    warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with parameters\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, param.type()))
+
+        # Backward compatibility for PyTorch 0.4
+        if hasattr(model, 'named_buffers'):
+            buf_iter = model.named_buffers()
+        else:
+            buf_iter = model._buffers
+        for obj in buf_iter:
+            if type(obj)==tuple:
+                name, buf = obj
+            else:
+                name, buf = obj, buf_iter[obj]
+            if buf.is_floating_point():
+                if 'Half' in buf.type():
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you do not need to call .half() on your model\n"
+                        "before passing it, no matter what optimization level you choose.".format(
+                        name, buf.type()))
+                elif not buf.is_cuda:
+                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
+                        "When using amp.initialize, you need to provide a model with buffers\n"
+                        "located on a CUDA device before passing it no matter what optimization level\n"
+                        "you chose. Use model.to('cuda') to use the default device.".format(
+                        name, buf.type()))
+
+
+def check_optimizers(optimizers):
+    for optim in optimizers:
+        bad_optim_type = None
+        if isinstance(optim, FP16_Optimizer_general):
+            bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
+        if isinstance(optim, FP16_Optimizer_for_fused):
+            bad_optim_type = "apex.optimizers.FP16_Optimizer"
+        if bad_optim_type is not None:
+            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
+                               "The optimizer(s) passed to amp.initialize() must be bare \n"
+                               "instances of either ordinary Pytorch optimizers, or Apex fused \n"
+                               "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
+                               "soon). You should not manually wrap your optimizer in either \n"
+                               "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
+                               "amp.initialize will take care of that for you (if necessary) based \n"
+                               "on the specified opt_level (and optional overridden properties).")
+
+
+def wrap_fused_adam(optimizer, properties):
+    msg = 'Currently, the usage of FusedAdam is restricted to '\
+          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
+          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
+
+    assert properties.master_weights is True, msg
+    assert properties.cast_model_type is torch.float16, msg
+    assert (properties.keep_batchnorm_fp32 is False or
+            properties.keep_batchnorm_fp32 is None), msg
+
+    if properties.loss_scale == "dynamic":
+        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
+    else:
+        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
+
+
+def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
+    from apex.parallel import DistributedDataParallel as apex_DDP
+    from .amp import init as amp_init
+
+    optimizers_was_list = False
+    if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
+        optimizers = [optimizers]
+    elif optimizers is None:
+        optimizers = []
+    elif isinstance(optimizers, list):
+        optimizers_was_list = True
+        check_optimizers(optimizers)
+    else:
+        check_optimizers([optimizers])
+        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
+
+    if isinstance(models, torch.nn.Module):
+        models_was_list = False
+        models = [models]
+    elif isinstance(models, list):
+        models_was_list = True
+    else:
+        raise TypeError("models must be either a single model or a list of models.")
+
+    check_models(models)
+
+    if not _amp_state.allow_incoming_model_not_fp32:
+        check_params_fp32(models)
+
+
+    # In the future, when FP16_Optimizer can be deprecated and master weights can
+    # become an attribute, remember to stash master weights before casting the model.
+
+    if properties.cast_model_type:
+        if properties.keep_batchnorm_fp32:
+            for model in models:
+                convert_network(model, properties.cast_model_type)
+        else:
+            for model in models:
+                model.to(properties.cast_model_type)
+
+        input_caster = functools.partial(to_type, properties.cast_model_type)
+        if cast_model_outputs is not None:
+            output_caster = functools.partial(to_type, cast_model_outputs)
+        else:
+            output_caster = functools.partial(to_type, torch.float32)
+
+        for model in models:
+            # Patch the forward method to cast incoming data to the correct type, and
+            # outgoing data to float32, so "the user never needs to call .half()."
+            # I like writing things explicitly more than decorators.
+            def patch_forward(old_fwd):
+                def new_fwd(*args, **kwargs):
+                    output = old_fwd(*applier(args, input_caster),
+                                     **applier(kwargs, input_caster))
+                    return applier(output, output_caster)
+                return new_fwd
+
+            model.forward = patch_forward(model.forward)
+
+        # State dict trick to recast any preexisting per-param state tensors
+        for optimizer in optimizers:
+            optimizer.load_state_dict(optimizer.state_dict())
+    elif cast_model_outputs is not None:
+        output_caster = functools.partial(to_type, cast_model_outputs)
+
+        for model in models:
+            def patch_forward(old_fwd):
+                def new_fwd(*args, **kwargs):
+                    output = old_fwd(*args, **kwargs)
+                    return applier(output, output_caster)
+                return new_fwd
+
+            model.forward = patch_forward(model.forward)
+
+    for i, optimizer in enumerate(optimizers):
+        # Still need to special case this for the first pass
+        if isinstance(optimizer, FusedAdam):
+            optimizers[i] = wrap_fused_adam(optimizer, properties)
+        else:
+            optimizers[i] = _process_optimizer(optimizer, properties)
+
+    _amp_state.loss_scalers = []
+    for _ in range(num_losses):
+        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
+                                                  min_loss_scale=_amp_state.min_loss_scale,
+                                                  max_loss_scale=_amp_state.max_loss_scale))
+
+    if properties.patch_torch_functions:
+        # handle is unused here. It's accessible later through a global value anyway.
+        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
+        for optimizer in optimizers:
+            # Disable Amp casting for the optimizer step, because it should only be
+            # applied to FP32 master params anyway.
+            def patch_step(old_step):
+                def new_step(*args, **kwargs):
+                    with disable_casts():
+                        output = old_step(*args, **kwargs)
+                    return output
+                return new_step
+
+            optimizer.step = patch_step(optimizer.step)
+
+    if optimizers_was_list:
+        if models_was_list:
+            return models, optimizers
+        else:
+            return models[0], optimizers
+    else:
+        if models_was_list:
+            if len(optimizers) == 0:
+                return models
+            else:
+                return models, optimizers[0]
+        else:
+            if len(optimizers) == 0:
+                return models[0]
+            else:
+                return models[0], optimizers[0]
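The forward-patching above is a closure over the original `forward`, with `applier` recursively casting tensors inside nested containers. A minimal standalone sketch of the same idea, with no apex imports: `cast_floats`, `apply_nested`, and `patch_forward` below are made-up stand-ins for `to_type`/`applier` and the inner patching helper, and a CUDA device is assumed.

```python
import functools
import torch

def cast_floats(t, dtype):
    # Only floating-point tensors are converted; everything else passes through.
    return t.to(dtype) if isinstance(t, torch.Tensor) and t.is_floating_point() else t

def apply_nested(value, fn):
    # Tiny stand-in for applier(): recurse through tuples, lists, and dict values.
    if isinstance(value, (list, tuple)):
        return type(value)(apply_nested(v, fn) for v in value)
    if isinstance(value, dict):
        return {k: apply_nested(v, fn) for k, v in value.items()}
    return fn(value)

def patch_forward(old_fwd):
    def new_fwd(*args, **kwargs):
        # Inputs go to FP16, outputs come back as FP32, mirroring _initialize.
        to_half = functools.partial(cast_floats, dtype=torch.float16)
        to_float = functools.partial(cast_floats, dtype=torch.float32)
        out = old_fwd(*apply_nested(args, to_half), **apply_nested(kwargs, to_half))
        return apply_nested(out, to_float)
    return new_fwd

model = torch.nn.Linear(4, 4).half().cuda()
model.forward = patch_forward(model.forward)
out = model(torch.randn(2, 4, device="cuda"))  # FP32 input is accepted
print(out.dtype)                               # torch.float32
```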
apex/apex/amp/_process_optimizer.py
ADDED
@@ -0,0 +1,411 @@
+import types
+from ..fp16_utils import master_params_to_model_params
+from ..multi_tensor_apply import multi_tensor_applier
+from ._amp_state import maybe_print
+import torch
+
+
+class AmpOptimizerState(object):
+    def __init__(self):
+        pass
+
+
+def lazy_init_with_master_weights(self):
+    stash = self._amp_stash
+    stash.fp16_groups = []
+    stash.fp32_from_fp16_groups = []
+    stash.fp32_from_fp32_groups = []
+    for i, param_group in enumerate(self.param_groups):
+        # maybe_print("FP16_Optimizer processing param group {}:".format(i))
+        fp16_params_this_group = []
+        fp32_params_this_group = []
+        fp32_from_fp16_params_this_group = []
+        for i, param in enumerate(param_group['params']):
+            if param.requires_grad:
+                if param.type() == 'torch.cuda.HalfTensor':
+                    # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
+                    #             .format(param.size()))
+                    fp16_params_this_group.append(param)
+                    master_param = param.detach().clone().float()
+                    master_param.requires_grad = True
+                    param_group['params'][i] = master_param
+                    fp32_from_fp16_params_this_group.append(master_param)
+                    # Reset existing state dict key to the new master param.
+                    # We still need to recast per-param state tensors, if any, to FP32.
+                    if param in self.state:
+                        self.state[master_param] = self.state.pop(param)
+                elif param.type() == 'torch.cuda.FloatTensor':
+                    # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
+                    #             .format(param.size()))
+                    fp32_params_this_group.append(param)
+                    param_group['params'][i] = param
+                else:
+                    raise TypeError("Optimizer's parameters must be either "
+                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                    "Received {}".format(param.type()))
+
+        stash.fp16_groups.append(fp16_params_this_group)
+        stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+        stash.fp32_from_fp32_groups.append(fp32_params_this_group)
+
+    stash.all_fp16_params = []
+    for group in stash.fp16_groups:
+        stash.all_fp16_params += group
+
+    stash.all_fp32_from_fp16_params = []
+    for group in stash.fp32_from_fp16_groups:
+        stash.all_fp32_from_fp16_params += group
+
+    stash.all_fp32_from_fp32_params = []
+    for group in stash.fp32_from_fp32_groups:
+        stash.all_fp32_from_fp32_params += group
+
+    # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
+    stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
+
+    for param in stash.all_fp32_from_fp16_params:
+        param.grad = None
+
+    for param in stash.all_fp32_from_fp32_params:
+        param.grad = None
+
+    # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
+    self.load_state_dict(self.state_dict())
+
+
+def prepare_backward_with_master_weights(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+    for i, param in enumerate(stash.all_fp16_params):
+        # Set up to leverage grad copy elision:
+        param.grad = None
+
+    # for i, param in enumerate(stash.all_fp32_from_fp16_params):
+    #     stash.all_fp32_from_fp16_grad_stash[i] = param.grad
+
+    for i, param in enumerate(stash.all_fp32_from_fp32_params):
+        stash.all_fp32_from_fp32_grad_stash[i] = param.grad
+        # Set up to leverage grad copy elision:
+        param.grad = None
+
+
+def post_backward_with_master_weights(self, scaler):
+    stash = self._amp_stash
+
+    # This is a lot of python overhead...
+    fp16_grads_needing_unscale = []
+    new_fp32_grads = []
+    fp16_grads_needing_unscale_with_stash = []
+    preexisting_fp32_grads = []
+    for fp16_param, fp32_param in zip(stash.all_fp16_params,
+                                      stash.all_fp32_from_fp16_params):
+        if fp16_param.grad is None and fp32_param.grad is not None:
+            continue
+        elif fp16_param.grad is not None and fp32_param.grad is None:
+            fp32_param.grad = torch.empty_like(fp32_param)
+            fp16_grads_needing_unscale.append(fp16_param.grad)
+            new_fp32_grads.append(fp32_param.grad)
+        elif fp16_param.grad is not None and fp32_param.grad is not None:
+            fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
+            preexisting_fp32_grads.append(fp32_param.grad)
+        else: # fp16_param.grad is None and fp32_param.grad is None:
+            continue
+
+    if len(fp16_grads_needing_unscale) > 0:
+        scaler.unscale(
+            fp16_grads_needing_unscale,
+            new_fp32_grads,
+            scaler.loss_scale(),
+            models_are_masters=False)
+
+    if len(fp16_grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            fp16_grads_needing_unscale_with_stash,
+            preexisting_fp32_grads,
+            preexisting_fp32_grads)
+
+    # fp32 params can be treated as they would be in the "no_master_weights" case.
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
+                                   stash.all_fp32_from_fp32_grad_stash):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else: # param.grad is None and stashed_grad is None:
+            continue
+
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            scaler.loss_scale(),
+            models_are_masters=True)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash)
+
+    # Clear the stash.
+    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
+        stash.all_fp32_from_fp32_grad_stash[i] = None
+
+
+def lazy_init_no_master_weights(self):
+    stash = self._amp_stash
+    stash.all_fp16_params = []
+    stash.all_fp32_params = []
+    for i, param_group in enumerate(self.param_groups):
+        for i, param in enumerate(param_group['params']):
+            if param.type() == 'torch.cuda.HalfTensor':
+                stash.all_fp16_params.append(param)
+            elif param.type() == 'torch.cuda.FloatTensor':
+                stash.all_fp32_params.append(param)
+            else:
+                raise TypeError("Optimizer's parameters must be either "
+                                "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                "Received {}".format(param.type()))
+
+    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
+    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
+
+
+def prepare_backward_no_master_weights(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+    for i, param in enumerate(stash.all_fp16_params):
+        stash.all_fp16_grad_stash[i] = param.grad
+        # Set up to leverage grad copy elision:
+        param.grad = None
+
+    for i, param in enumerate(stash.all_fp32_params):
+        stash.all_fp32_grad_stash[i] = param.grad
+        # Set up to leverage grad copy elision:
+        param.grad = None
+
+
+def post_backward_no_master_weights(self, scaler):
+    stash = self._amp_stash
+
+    split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
+                   (stash.all_fp32_params, stash.all_fp32_grad_stash))
+
+    for params, stashed_grads in split_types:
+        # This is a lot of python overhead...
+        grads_needing_unscale = []
+        grads_needing_unscale_with_stash = []
+        stashed = []
+        for param, stashed_grad in zip(params, stashed_grads):
+            if param.grad is None and stashed_grad is not None:
+                param.grad = stashed_grad
+            elif param.grad is not None and stashed_grad is None:
+                grads_needing_unscale.append(param.grad)
+            elif param.grad is not None and stashed_grad is not None:
+                grads_needing_unscale_with_stash.append(param.grad)
+                stashed.append(stashed_grad)
+            else: # param.grad is None and stashed_grad is None
+                continue
+
+        if len(grads_needing_unscale) > 0:
+            scaler.unscale(
+                grads_needing_unscale,
+                grads_needing_unscale,
+                scaler.loss_scale(),
+                models_are_masters=True)
+
+        if len(grads_needing_unscale_with_stash) > 0:
+            scaler.unscale_with_stashed(
+                grads_needing_unscale_with_stash,
+                stashed,
+                grads_needing_unscale_with_stash)
+
+        # Clear the stash.
+        for i in range(len(stashed_grads)):
+            stashed_grads[i] = None
+
+
+def _master_params_to_model_params(self):
+    stash = self._amp_stash
+    if multi_tensor_applier.available:
+        if len(stash.all_fp16_params) > 0:
+            multi_tensor_applier(
+                stash.multi_tensor_scale,
+                stash.dummy_overflow_buf,
+                [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
+                1.0)
+    else:
+        for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
+
+def _process_optimizer(optimizer, properties):
+    if hasattr(optimizer, "_amp_stash"):
+        raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
+    else:
+        optimizer._amp_stash = AmpOptimizerState()
+
+    optimizer._amp_stash.lazy_init_called = False
+    optimizer._amp_stash.already_patched = False
+    optimizer._amp_stash.params_have_scaled_gradients = False
+
+    for name in ("_lazy_init_maybe_master_weights",
+                 "_master_params_to_model_params",
+                 "_prepare_amp_backward",
+                 "_post_amp_backward"):
+        if hasattr(optimizer, name):
+            raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
+
+    # TODO: Centralize exposure and import error checking for the C backend.
+    if multi_tensor_applier.available:
+        import amp_C
+        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+    if properties.master_weights:
+        optimizer._lazy_init_maybe_master_weights = types.MethodType(
+            lazy_init_with_master_weights, optimizer)
+
+        optimizer._master_params_to_model_params = types.MethodType(
+            _master_params_to_model_params, optimizer)
+
+        old_step = optimizer.step
+        def new_step(self, closure=None):
+            if closure is not None:
+                raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
+            retval = old_step()
+            self._master_params_to_model_params()
+            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
+            for param in self._amp_stash.all_fp32_from_fp16_params:
+                param.grad = None
+            return retval
+        optimizer.step = types.MethodType(new_step, optimizer)
+
+        old_zero_grad = optimizer.zero_grad
+        def new_zero_grad(self):
+            stash = self._amp_stash
+            if not stash.lazy_init_called:
+                self._lazy_init_maybe_master_weights()
+                stash.lazy_init_called = True
+            # Zero the model grads.
+            for param in stash.all_fp16_params:
+                if param.grad is not None:
+                    param.grad.detach_()
+                    param.grad.zero_()
+            for param in stash.all_fp32_from_fp32_params:
+                if param.grad is not None:
+                    param.grad.detach_()
+                    param.grad.zero_()
+            # Clear the master grads that are independent of model grads
+            for param in self._amp_stash.all_fp32_from_fp16_params:
+                param.grad = None
+        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
+
+        optimizer._prepare_amp_backward = types.MethodType(
+            prepare_backward_with_master_weights, optimizer)
+
+        optimizer._post_amp_backward = types.MethodType(
+            post_backward_with_master_weights, optimizer)
+    else:
+        optimizer._lazy_init_maybe_master_weights = types.MethodType(
+            lazy_init_no_master_weights, optimizer)
+
+        optimizer._prepare_amp_backward = types.MethodType(
+            prepare_backward_no_master_weights, optimizer)
+
+        optimizer._post_amp_backward = types.MethodType(
+            post_backward_no_master_weights, optimizer)
+
+    old_add_param_group = optimizer.add_param_group
+
+    def new_add_param_group(self, new_group):
+        stash = self._amp_stash
+
+        if not stash.lazy_init_called:
+            self._lazy_init_maybe_master_weights()
+            stash.lazy_init_called = True
+
+        assert isinstance(new_group, dict), "param group must be a dict"
+
+        new_params = new_group['params']
+        if isinstance(new_params, torch.Tensor):
+            new_group['params'] = [new_params]
+        elif isinstance(new_params, set):
+            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
+                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
+        else:
+            new_group['params'] = list(new_params)
+
+        if properties.master_weights:
+            # Mutate new_group in-place to use FP32 master params
+            fp16_params_this_group = []
+            fp32_params_this_group = []
+            fp32_from_fp16_params_this_group = []
+            for i, param in enumerate(new_group['params']):
+                if param.requires_grad:
+                    if param.type() == 'torch.cuda.HalfTensor':
+                        fp16_params_this_group.append(param)
+                        master_param = param.detach().clone().float()
+                        master_param.requires_grad = True
+                        new_group['params'][i] = master_param
+                        fp32_from_fp16_params_this_group.append(master_param)
+                    elif param.type() == 'torch.cuda.FloatTensor':
+                        fp32_params_this_group.append(param)
+                        new_group['params'][i] = param
+                    else:
+                        raise TypeError("Optimizer's parameters must be either "
+                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                        "Received {}".format(param.type()))
+
+            stash.fp16_groups.append(fp16_params_this_group)
+            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+            stash.fp32_from_fp32_groups.append(fp32_params_this_group)
+
+            stash.all_fp16_params += fp16_params_this_group
+            stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
+            stash.all_fp32_from_fp32_params += fp32_params_this_group
+
+            # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
+            stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]
+
+            # It should be ok to let params be added with existing .grad attributes.
+            # for param in fp16_params_this_group:
+            #     param.grad = None
+
+            # for param in fp32_from_fp16_params_this_group:
+            #     param.grad = None
+
+            # for param in stash.fp32_params_this_group:
+            #     param.grad = None
+        else:
+            for param in new_group['params']:
+                if param.type() == 'torch.cuda.HalfTensor':
+                    stash.all_fp16_params.append(param)
+                    stash.all_fp16_grad_stash.append(None)
+                elif param.type() == 'torch.cuda.FloatTensor':
+                    stash.all_fp32_params.append(param)
+                    stash.all_fp32_grad_stash.append(None)
+                else:
+                    raise TypeError("Optimizer's parameters must be either "
+                                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
+                                    "Received {}".format(param.type()))
+
+        old_add_param_group(new_group)
+
+    optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)
+
+    return optimizer
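The master-weights path above boils down to a standard FP16 recipe: keep an FP32 copy of each FP16 parameter, step the optimizer on the FP32 copies, then copy the results back into the FP16 model weights. A stripped-down sketch of that core loop, with all names local to the example; it deliberately omits the loss scaling, unscaling, and grad stashing that the real implementation handles, and assumes a CUDA device:

```python
import torch

model = torch.nn.Linear(16, 16).cuda().half()
# FP32 master copies of the FP16 model parameters.
master_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=0.1)

inp = torch.randn(8, 16, device="cuda", dtype=torch.float16)
loss = model(inp).float().sum()
loss.backward()  # produces FP16 grads on the model params

with torch.no_grad():
    # Copy FP16 grads onto the FP32 masters, step, then copy masters back.
    for master, p in zip(master_params, model.parameters()):
        master.grad = p.grad.float()
    optimizer.step()
    for master, p in zip(master_params, model.parameters()):
        p.copy_(master)
```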
apex/apex/amp/amp.py
ADDED
@@ -0,0 +1,177 @@
+from . import compat, rnn_compat, utils, wrap
+from .handle import AmpHandle, NoOpHandle
+from .lists import functional_overrides, torch_overrides, tensor_overrides
+from ._amp_state import _amp_state
+from .frontend import *
+
+import functools
+import itertools
+
+import torch
+
+
+_DECORATOR_HANDLE = None
+_USER_CAST_REGISTRY = set()
+_USER_PROMOTE_REGISTRY = set()
+
+
+def _decorator_helper(orig_fn, cast_fn, wrap_fn):
+    def wrapper(*args, **kwargs):
+        handle = _DECORATOR_HANDLE
+        if handle is None or not handle.is_active():
+            return orig_fn(*args, **kwargs)
+        inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
+                                        handle.verbose)
+        return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
+    return wrapper
+
+
+# Decorator form
+def half_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
+    return _decorator_helper(fn, utils.maybe_half, wrap_fn)
+
+
+def float_function(fn):
+    wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+
+def promote_function(fn):
+    wrap_fn = functools.partial(wrap.make_promote_wrapper)
+    return _decorator_helper(fn, utils.maybe_float, wrap_fn)
+
+
+# Registry form
+def register_half_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
+
+
+def register_float_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
+
+
+def register_promote_function(module, name):
+    if not hasattr(module, name):
+        raise ValueError('No function named {} in module {}.'.format(
+            name, module))
+    _USER_PROMOTE_REGISTRY.add((module, name))
+
+
+# Top-level function to insert _all_ the hooks.
+def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
+    global _DECORATOR_HANDLE
+
+    if not enabled:
+        handle = NoOpHandle()
+        _DECORATOR_HANDLE = handle
+        return handle
+
+    handle = AmpHandle(loss_scale, enable_caching, verbose)
+
+    # 0) Force-{fp16, fp32} for user-annotated functions
+    for mod, fn, cast_fn in _USER_CAST_REGISTRY:
+        try_caching = (cast_fn == utils.maybe_half)
+        wrap.cached_cast(mod, fn, cast_fn, handle,
+                         try_caching, verbose)
+    _USER_CAST_REGISTRY.clear()
+
+    # 0.5) Force-promote for user-annotated functions
+    for mod, fn in _USER_PROMOTE_REGISTRY:
+        wrap.promote(mod, fn, handle, verbose)
+    _USER_PROMOTE_REGISTRY.clear()
+
+    # 1) Force-{fp16, fp32} on white- / black-list functions
+    override_modules = [functional_overrides,
+                        torch_overrides,
+                        tensor_overrides]
+    cast_table = [('FP16_FUNCS', utils.maybe_half),
+                  ('FP32_FUNCS', utils.maybe_float)]
+    for module, (list_name, cast_fn) in itertools.product(override_modules,
+                                                          cast_table):
+        for fn in getattr(module, list_name):
+            try_caching = (cast_fn == utils.maybe_half)
+            wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
+                             try_caching, verbose)
+
+    # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
+    #      methods on FloatTensor, since they're distinct types.
+    if compat.tensor_is_float_tensor():
+        for fn in tensor_overrides.FP16_FUNCS:
+            wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
+                             handle, try_caching=True, verbose=verbose)
+        for fn in tensor_overrides.FP32_FUNCS:
+            wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
+                             handle, try_caching=False, verbose=verbose)
+
+    # 2) Enable type-promotion on multi-arg functions and methods.
+    #    NB: special handling for sequence fns (e.g. `torch.cat`).
+    promote_modules = [torch_overrides, tensor_overrides]
+    promote_table = [('CASTS', wrap.promote),
+                     ('SEQUENCE_CASTS', wrap.sequence_promote)]
+    for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
+                                                                  promote_table):
+        for fn in getattr(promote_mod, list_name):
+            promote_fn(promote_mod.MODULE, fn, handle, verbose)
+
+    # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
+    if compat.tensor_is_float_tensor():
+        for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
+                                                               torch.cuda.HalfTensor],
+                                                              promote_table):
+            for fn in getattr(tensor_overrides, list_name):
+                promote_fn(cls, fn, handle, verbose)
+
+    # 3) For any in-place version of a blacklist function, error if any input is fp16.
+    #    NB: this is overly conservative.
+    for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
+        wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
+
+    # 3.5) For any in-place blacklist method, error if called on fp16 tensor
+    for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
+        wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
+        if compat.tensor_is_float_tensor():
+            wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
+
+    # 4) For other in-place methods, match the type of self tensor
+    for fn in utils.as_inplace(itertools.chain(
+            tensor_overrides.FP16_FUNCS,
+            tensor_overrides.CASTS)):
+        wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
+        if compat.tensor_is_float_tensor():
+            wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
+            wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
+
+    # 5) RNNs + RNN cells are whitelisted specially
+    if rnn_compat.has_old_rnns():
+        wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
+    if not rnn_compat.has_old_rnns():
+        # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
+        torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
+        # Wrap all the rnns
+        for x in rnn_compat.RNN_NAMES:
+            wrap.new_rnn_cast(x.upper(), handle, verbose)
+
+        # Wrap all the RNN cells
+        rnn_compat.whitelist_rnn_cells(handle, verbose)
+
+    # 6) Place error+print message on banned functions.
+    #    Or, if allow_banned, then cast to FP32.
+    for fn, err_msg in functional_overrides.BANNED_FUNCS:
+        if allow_banned:
+            wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
+                             handle, try_caching=True, verbose=verbose)
+        else:
+            wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
+
+    _DECORATOR_HANDLE = handle
+
+    _amp_state.handle = handle
+
+    return handle
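The decorator form is inert until `init()` (or the `initialize` frontend, which calls it for the patching opt level) installs an active handle; before that, `_decorator_helper` simply calls through to the original function. A short sketch, with `stable_norm` as a made-up user function:

```python
import torch
from apex import amp

@amp.float_function
def stable_norm(x):
    # Forced to run in FP32 once an amp handle is active; a pass-through wrapper before that.
    return torch.sqrt((x * x).sum() + 1e-6)
```

Whether the wrapper tries to cache casted arguments (`half_function`) or not (`float_function`, `promote_function`) mirrors the `try_caching` flags in the decorator helpers above.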
apex/apex/amp/compat.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+
+# True for post-0.4, when Variables/Tensors merged.
+def variable_is_tensor():
+    v = torch.autograd.Variable()
+    return isinstance(v, torch.Tensor)
+
+def tensor_is_variable():
+    x = torch.Tensor()
+    return type(x) == torch.autograd.Variable
+
+# False for post-0.4
+def tensor_is_float_tensor():
+    x = torch.Tensor()
+    return type(x) == torch.FloatTensor
+
+# Akin to `torch.is_tensor`, but returns True for Variable
+# objects in pre-0.4.
+def is_tensor_like(x):
+    return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)
+
+# Wraps `torch.is_floating_point` if present, otherwise checks
+# the suffix of `x.type()`.
+def is_floating_point(x):
+    if hasattr(torch, 'is_floating_point'):
+        return torch.is_floating_point(x)
+    try:
+        torch_type = x.type()
+        return torch_type.endswith('FloatTensor') or \
+            torch_type.endswith('HalfTensor') or \
+            torch_type.endswith('DoubleTensor')
+    except AttributeError:
+        return False
+
+def scalar_python_val(x):
+    if hasattr(x, 'item'):
+        return x.item()
+    else:
+        if isinstance(x, torch.autograd.Variable):
+            return x.data[0]
+        else:
+            return x[0]
apex/apex/amp/frontend.py
ADDED
@@ -0,0 +1,399 @@
+import torch
+from ._initialize import _initialize
+from ._amp_state import _amp_state, warn_or_err, maybe_print
+
+
+class Properties(object):
+    """
+    This class has two purposes: to establish a set of default properties,
+    and to route setting of these attributes through __setattr__ so that (in theory)
+    they can be checked for consistency with other existing args.
+    """
+    def __init__(self):
+        self.options = {
+            "enabled" : False,
+            "opt_level" : None,
+            "cast_model_type" : None,
+            "patch_torch_functions" : False,
+            "keep_batchnorm_fp32" : None,
+            "master_weights" : None,
+            "loss_scale" : 1.0,
+            # Reserved for future functionality
+            # "fused_optimizer" : False,
+            # "enable_ddp_interop" : False,
+        }
+
+    """
+    This function allows updating several options at a time without routing through
+    __setattr__ checks, to avoid "you can't get there from here" scenarios.
+    Currently not intended to be exposed; users are expected to select an opt_level
+    and apply consistent modifications.
+    """
+    def _update_options_dict(new_options):
+        for k, v in new_options:
+            if k in self.options:
+                self.options[k] = v
+            else:
+                raise ValueError("Tried to set unexpected option {}".format(k))
+    """
+    The members of "options" are not direct attributes of self, so access attempts
+    will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
+    """
+    def __getattr__(self, name):
+        if "options" in self.__dict__:
+            options = self.__dict__["options"]
+            if name in options:
+                return options[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, name))
+
+    def __setattr__(self, name, value):
+        if "options" in self.__dict__:
+            if name in self.options:
+                # print("setting {} {}".format(name, value))
+                if name == "cast_model_type":
+                    if self.opt_level == "O1" and value is not None:
+                        if value is not False:
+                            if value is not torch.float32:
+                                warn_or_err("O1 inserts casts around Torch functions rather than "
+                                            "model weights, so with O1, the model weights themselves "
+                                            "should remain FP32. If you wish to cast the model to a "
+                                            "different type, use opt_level='O2' or 'O3'. " +
+                                            "cast_model_type was {}".format(value))
+                    self.options[name] = value
+                elif name == "patch_torch_functions":
+                    if self.opt_level != "O1" and value:
+                        warn_or_err("Currently, patch_torch_functions=True should only be set by "
+                                    "selecting opt_level='O1'.")
+                    self.options[name] = value
+                elif name == "keep_batchnorm_fp32":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
+                                    "to run in FP32, so keep_batchnorm_fp32 should be None." +
+                                    " keep_batchnorm_fp32 was {}".format(value))
+                    if value == "False":
+                        self.options[name] = False
+                    elif value == "True":
+                        self.options[name] = True
+                    else:
+                        assert (value is True or value is False or value is None),\
+                            "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
+                            "or None, found keep_batchnorm_fp32={}".format(value)
+                        self.options[name] = value
+                elif name == "master_weights":
+                    if self.opt_level == "O1" and value is not None:
+                        warn_or_err("It doesn't make sense to use master_weights with O1. "
+                                    "With O1, your model weights themselves should be FP32.")
+                    self.options[name] = value
+                elif name == "loss_scale":
+                    if value == "dynamic":
+                        self.options[name] = value
+                    else:
+                        self.options[name] = float(value)
+                else:
+                    self.options[name] = value
+        else:
+            super(Properties, self).__setattr__(name, value)
+
+
+""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
+
+class O3:
+    brief = "O3: Pure FP16 training."
+    more = "Calls .half() on your model, converting the entire model to FP16.\n"\
+        "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
+        "so you don't need to change your data pipeline.\n"\
+        "This mode is useful for establishing a performance ceiling.\n"\
+        "It's also possible training may 'just work' in this mode.\n"\
+        "If not, try other optimization levels."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O3"
+        properties.cast_model_type = torch.float16
+        properties.patch_torch_functions = False
+        properties.keep_batchnorm_fp32 = False
+        properties.master_weights = False
+        properties.loss_scale = 1.0
+        # properties.fused_optimizer = False
+        # properties.enable_ddp_interop = False
+        return properties # modified in place so this isn't really necessary
+
+
+class O2:
+    brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
+    more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
+        "to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
+        "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
+        "your data pipeline.\n"\
+        "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
+        "these master weights, then copy the master weights into the FP16 model weights.\n"\
+        "Master weights can also improve convergence and stability."
+
+    def __call__(self, properties):
+        properties.enabled = True
+        properties.opt_level = "O2"
+        properties.cast_model_type = torch.float16
+        properties.patch_torch_functions = False
+        properties.keep_batchnorm_fp32 = True
+        properties.master_weights = True
+        properties.loss_scale = "dynamic"
|
141 |
+
# properties.fused_optimizer = False
|
142 |
+
# properties.enable_ddp_interop = False
|
143 |
+
return properties # modified in place so this isn't really necessary
|
144 |
+
|
145 |
+
|
146 |
+
class O1:
|
147 |
+
brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
|
148 |
+
more = "The type of your model's weights is not altered. However, internally,\n"\
|
149 |
+
"Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
|
150 |
+
"while operations that might benefit from the additional stability of FP32 are patched\n"\
|
151 |
+
"to cast their inputs to fp32.\n"\
|
152 |
+
"O1 is the safest way to try mixed precision training, and is recommended when\n"\
|
153 |
+
"trying mixed precision training for the first time."
|
154 |
+
|
155 |
+
def __call__(self, properties):
|
156 |
+
properties.enabled = True
|
157 |
+
properties.opt_level = "O1"
|
158 |
+
properties.cast_model_type = None
|
159 |
+
properties.patch_torch_functions = True
|
160 |
+
properties.keep_batchnorm_fp32 = None
|
161 |
+
properties.master_weights = None
|
162 |
+
properties.loss_scale = "dynamic"
|
163 |
+
# properties.fused_optimizer = False
|
164 |
+
# properties.enable_ddp_interop = False
|
165 |
+
return properties # modified in place so this isn't really necessary
|
166 |
+
|
167 |
+
|
168 |
+
class O0:
|
169 |
+
brief = "O0: Pure FP32 training.\n"
|
170 |
+
more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
|
171 |
+
"types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
|
172 |
+
"FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
|
173 |
+
|
174 |
+
def __call__(self, properties):
|
175 |
+
properties.enabled = True
|
176 |
+
properties.opt_level = "O0"
|
177 |
+
properties.cast_model_type = torch.float32
|
178 |
+
properties.patch_torch_functions = False
|
179 |
+
properties.keep_batchnorm_fp32 = None
|
180 |
+
properties.master_weights = False
|
181 |
+
properties.loss_scale = 1.0
|
182 |
+
# properties.fused_optimizer = False
|
183 |
+
# properties.enable_ddp_interop = False
|
184 |
+
return properties # modified in place so this isn't really necessary
|
185 |
+
|
186 |
+
|
187 |
+
opt_levels = {"O3": O3(),
|
188 |
+
"O2": O2(),
|
189 |
+
"O1": O1(),
|
190 |
+
"O0": O0()}
|
191 |
+
|
192 |
+
|
193 |
+
# allow user to directly pass Properties struct as well?
|
194 |
+
def initialize(
|
195 |
+
models,
|
196 |
+
optimizers=None,
|
197 |
+
enabled=True,
|
198 |
+
opt_level="O1",
|
199 |
+
cast_model_type=None,
|
200 |
+
patch_torch_functions=None,
|
201 |
+
keep_batchnorm_fp32=None,
|
202 |
+
master_weights=None,
|
203 |
+
loss_scale=None,
|
204 |
+
cast_model_outputs=None,
|
205 |
+
num_losses=1,
|
206 |
+
verbosity=1,
|
207 |
+
min_loss_scale=None,
|
208 |
+
max_loss_scale=2.**24
|
209 |
+
):
|
210 |
+
"""
|
211 |
+
Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
|
212 |
+
chosen ``opt_level`` and overridden properties, if any.
|
213 |
+
|
214 |
+
``amp.initialize`` should be called **after** you have finished
|
215 |
+
constructing your model(s) and
|
216 |
+
optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
|
217 |
+
See `Distributed training`_ in the Imagenet example.
|
218 |
+
|
219 |
+
Currently, ``amp.initialize`` should only be called **once**,
|
220 |
+
although it can process an arbitrary number of
|
221 |
+
models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
|
222 |
+
If you think your use case requires ``amp.initialize`` to be called more than once,
|
223 |
+
`let us know`_.
|
224 |
+
|
225 |
+
Any property keyword argument that is not ``None`` will be interpreted as a manual override.
|
226 |
+
|
227 |
+
To prevent having to rewrite anything else in your script, name the returned models/optimizers
|
228 |
+
to replace the passed models/optimizers, as in the code sample below.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
|
232 |
+
optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
|
233 |
+
REQUIRED for training, optional for inference.
|
234 |
+
enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
|
235 |
+
should run as if Amp were not present.
|
236 |
+
opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
|
237 |
+
"O0", "O1", "O2", and "O3", explained in detail above.
|
238 |
+
cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
|
239 |
+
above.
|
240 |
+
patch_torch_functions (bool, optional, default=None): Optional property override.
|
241 |
+
keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
|
242 |
+
passed as a string, must be the string "True" or "False".
|
243 |
+
master_weights (bool, optional, default=None): Optional property override.
|
244 |
+
loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
|
245 |
+
must be a string representing a number, e.g., "128.0", or the string "dynamic".
|
246 |
+
cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
|
247 |
+
of your model(s) are always cast to a particular type regardless of ``opt_level``.
|
248 |
+
num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
|
249 |
+
passes you plan to use. When used in conjunction with the ``loss_id`` argument to
|
250 |
+
``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
|
251 |
+
which can improve stability. See "Multiple models/optimizers/losses"
|
252 |
+
under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
|
253 |
+
support multiple losses/backward passes, but use a single global loss scale
|
254 |
+
for all of them.
|
255 |
+
verbosity (int, default=1): Set to 0 to suppress Amp-related output.
|
256 |
+
min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
|
257 |
+
loss scaling. The default value of None means that no floor is imposed.
|
258 |
+
If dynamic loss scaling is not used, `min_loss_scale` is ignored.
|
259 |
+
max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
|
260 |
+
dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
Model(s) and optimizer(s) modified according to the ``opt_level``.
|
264 |
+
If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
|
265 |
+
also be a list.
|
266 |
+
|
267 |
+
Permissible invocations::
|
268 |
+
|
269 |
+
model, optim = amp.initialize(model, optim,...)
|
270 |
+
model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
|
271 |
+
[model1, model2], optim = amp.initialize([model1, model2], optim,...)
|
272 |
+
[model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
|
273 |
+
|
274 |
+
# This is not an exhaustive list of the cross product of options that are possible,
|
275 |
+
# just a set of examples.
|
276 |
+
model, optim = amp.initialize(model, optim, opt_level="O0")
|
277 |
+
model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
|
278 |
+
|
279 |
+
model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
|
280 |
+
model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
|
281 |
+
|
282 |
+
model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
|
283 |
+
model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
|
284 |
+
model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
|
285 |
+
|
286 |
+
model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
|
287 |
+
model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
|
288 |
+
model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
|
289 |
+
|
290 |
+
The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
|
291 |
+
|
292 |
+
.. _`Distributed training`:
|
293 |
+
https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
|
294 |
+
|
295 |
+
.. _`Imagenet example`:
|
296 |
+
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
|
297 |
+
|
298 |
+
.. _`Advanced Amp Usage`:
|
299 |
+
https://nvidia.github.io/apex/advanced.html
|
300 |
+
|
301 |
+
.. _`Advanced Amp Usage topic`:
|
302 |
+
https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
|
303 |
+
|
304 |
+
.. _`let us know`:
|
305 |
+
https://github.com/NVIDIA/apex/issues
|
306 |
+
"""
|
307 |
+
_amp_state.opt_properties = Properties()
|
308 |
+
_amp_state.verbosity = verbosity
|
309 |
+
|
310 |
+
if not enabled:
|
311 |
+
if optimizers is None:
|
312 |
+
return models
|
313 |
+
else:
|
314 |
+
return models, optimizers
|
315 |
+
|
316 |
+
if not torch.backends.cudnn.enabled:
|
317 |
+
raise RuntimeError(
|
318 |
+
"Amp requires torch.backends.cudnn.enabled = True")
|
319 |
+
|
320 |
+
if opt_level not in opt_levels:
|
321 |
+
raise RuntimeError(
|
322 |
+
"Unexpected optimization level {}. ".format(opt_level) +
|
323 |
+
"Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
|
324 |
+
"not the number zero.")
|
325 |
+
else:
|
326 |
+
_amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
|
327 |
+
maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
|
328 |
+
maybe_print("Defaults for this optimization level are:", True)
|
329 |
+
for k, v in _amp_state.opt_properties.options.items():
|
330 |
+
maybe_print("{:22} : {}".format(k, v), True)
|
331 |
+
|
332 |
+
_amp_state.min_loss_scale = min_loss_scale
|
333 |
+
_amp_state.max_loss_scale = max_loss_scale
|
334 |
+
|
335 |
+
maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
|
336 |
+
# I chose to have the keyword arguments listed directly in the argument list,
|
337 |
+
# instead of **kwargs, so I can't use kwargs.items() here.
|
338 |
+
if enabled is not None:
|
339 |
+
_amp_state.opt_properties.enabled = enabled
|
340 |
+
if opt_level is not None:
|
341 |
+
_amp_state.opt_properties.opt_level = opt_level
|
342 |
+
if cast_model_type is not None:
|
343 |
+
_amp_state.opt_properties.cast_model_type = cast_model_type
|
344 |
+
if patch_torch_functions is not None:
|
345 |
+
_amp_state.opt_properties.patch_torch_functions = patch_torch_functions
|
346 |
+
if keep_batchnorm_fp32 is not None:
|
347 |
+
_amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
|
348 |
+
if master_weights is not None:
|
349 |
+
_amp_state.opt_properties.master_weights = master_weights
|
350 |
+
if loss_scale is not None:
|
351 |
+
_amp_state.opt_properties.loss_scale = loss_scale
|
352 |
+
|
353 |
+
maybe_print("After processing overrides, optimization options are:", True)
|
354 |
+
for k, v in _amp_state.opt_properties.options.items():
|
355 |
+
maybe_print("{:22} : {}".format(k, v), True)
|
356 |
+
|
357 |
+
return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
|
358 |
+
|
359 |
+
|
360 |
+
# TODO: is this necessary/useful?
|
361 |
+
# def check_option_consistency(enabled=True,
|
362 |
+
# opt_level=None,
|
363 |
+
# cast_model_type=None,
|
364 |
+
# patch_torch_functions=None,
|
365 |
+
# keep_batchnorm_fp32=None,
|
366 |
+
# master_weights=None,
|
367 |
+
# loss_scale=None,
|
368 |
+
# enable_ddp_interop=None,
|
369 |
+
# hard_override=False):
|
370 |
+
# """
|
371 |
+
# Utility function that enables users to quickly check if the option combination they intend
|
372 |
+
# to use is permitted. ``check_option_consistency`` does not require models or optimizers
|
373 |
+
# to be constructed, and can be called at any point in the script. ``check_option_consistency``
|
374 |
+
# is totally self-contained; it does not set any amp global state or affect anything outside
|
375 |
+
# of itself.
|
376 |
+
# """
|
377 |
+
#
|
378 |
+
# if not enabled:
|
379 |
+
# return
|
380 |
+
#
|
381 |
+
# if opt_level not in opt_levels:
|
382 |
+
# raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
|
383 |
+
# else:
|
384 |
+
# opt_properties = opt_levels[opt_level](Properties())
|
385 |
+
# print("Selected optimization level {}", opt_levels[opt_level].brief)
|
386 |
+
# print("Defaults for this optimization level are:")
|
387 |
+
# for k, v in opt_properties.options:
|
388 |
+
# print("{:22} : {}".format(k, v))
|
389 |
+
#
|
390 |
+
# print("Processing user overrides (additional kwargs that are not None)...")
|
391 |
+
# for k, v in kwargs:
|
392 |
+
# if k not in _amp_state.opt_properties.options:
|
393 |
+
# raise RuntimeError("Unexpected kwarg {}".format(k))
|
394 |
+
# if v is not None:
|
395 |
+
# setattr(opt_properties, k, v)
|
396 |
+
#
|
397 |
+
# print("After processing overrides, optimization options are:")
|
398 |
+
# for k, v in opt_properties.options:
|
399 |
+
# print("{:22} : {}".format(k, v))
|
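For quick reference, a minimal training-loop sketch using the API defined in frontend.py above. `model`, `optimizer`, `loader`, and `criterion` are assumed placeholders; the calls mirror the invocations shown in the docstring rather than adding new behavior.

from apex import amp

# The returned objects should replace the originals, as the docstring recommends.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for inputs, targets in loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    # Scale the loss so FP16 gradients don't underflow; gradients are
    # checked and unscaled on context-manager exit.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()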
apex/apex/amp/handle.py
ADDED
@@ -0,0 +1,280 @@
1 |
+
import contextlib
|
2 |
+
import warnings
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from . import utils
|
6 |
+
from .opt import OptimWrapper
|
7 |
+
from .scaler import LossScaler
|
8 |
+
from ._amp_state import _amp_state, master_params, maybe_print
|
9 |
+
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
|
10 |
+
from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
|
11 |
+
from ..parallel.LARC import LARC
|
12 |
+
|
13 |
+
|
14 |
+
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
|
15 |
+
@contextlib.contextmanager
|
16 |
+
def scale_loss(loss,
|
17 |
+
optimizers,
|
18 |
+
loss_id=0,
|
19 |
+
model=None,
|
20 |
+
delay_unscale=False,
|
21 |
+
delay_overflow_check=False):
|
22 |
+
"""
|
23 |
+
On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
|
24 |
+
``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
|
25 |
+
|
26 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
27 |
+
scaled_loss.backward()
|
28 |
+
|
29 |
+
On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
|
30 |
+
and unscaled, so that ``optimizer.step()`` can be called.
|
31 |
+
|
32 |
+
.. note::
|
33 |
+
If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
|
34 |
+
can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
|
35 |
+
any FP16 gradients are copied to FP32 master gradients before being unscaled.
|
36 |
+
``optimizer.step()`` will then apply the unscaled master gradients to the master params.
|
37 |
+
|
38 |
+
.. warning::
|
39 |
+
If Amp is using explicit FP32 master params, only the FP32 master gradients will be
|
40 |
+
unscaled. The direct ``.grad`` attributes of any FP16
|
41 |
+
model params will remain scaled after context manager exit.
|
42 |
+
This subtlety affects gradient clipping. See "Gradient clipping" under
|
43 |
+
`Advanced Amp Usage`_ for best practices.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
|
47 |
+
manager yields is simply ``loss.float()*loss_scale``, so in principle
|
48 |
+
``loss`` could have more than one element, as long as you call
|
49 |
+
``backward()`` on ``scaled_loss`` appropriately within the context manager body.
|
50 |
+
optimizers: All optimizer(s) for which the current backward pass is creating gradients.
|
51 |
+
Must be an optimizer or list of optimizers returned from an earlier call
|
52 |
+
to ``amp.initialize``. For example use with multiple optimizers, see
|
53 |
+
"Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
|
54 |
+
loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
|
55 |
+
to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
|
56 |
+
must be an integer between 0 and ``num_losses`` that tells Amp which loss is
|
57 |
+
being used for the current backward pass. See "Multiple models/optimizers/losses"
|
58 |
+
under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
|
59 |
+
will use the default global loss scaler for this backward pass.
|
60 |
+
model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
|
61 |
+
optimizations.
|
62 |
+
delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
|
63 |
+
the default value of ``False`` is strongly recommended.
|
64 |
+
If ``True``, Amp will not unscale the gradients or perform model->master
|
65 |
+
gradient copies on context manager exit.
|
66 |
+
``delay_unscale=True`` is a minor ninja performance optimization and can result
|
67 |
+
in weird gotchas (especially with multiple models/optimizers/losses),
|
68 |
+
so only use it if you know what you're doing.
|
69 |
+
"Gradient accumulation across iterations" under `Advanced Amp Usage`_
|
70 |
+
illustrates a situation where this CAN (but does not need to) be used.
|
71 |
+
|
72 |
+
.. warning::
|
73 |
+
If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
|
74 |
+
called yet after context manager exit, and must wait for another, later backward context
|
75 |
+
manager invocation with ``delay_unscale`` left to False.
|
76 |
+
|
77 |
+
.. _`Advanced Amp Usage`:
|
78 |
+
https://nvidia.github.io/apex/advanced.html
|
79 |
+
"""
|
80 |
+
if not hasattr(_amp_state, "opt_properties"):
|
81 |
+
raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
|
82 |
+
"model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
|
83 |
+
"before `with amp.scale_loss`.")
|
84 |
+
|
85 |
+
if not _amp_state.opt_properties.enabled:
|
86 |
+
yield loss
|
87 |
+
return
|
88 |
+
|
89 |
+
if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
|
90 |
+
optimizers = [optimizers]
|
91 |
+
|
92 |
+
# This is what happens when I have to support tools from different sources under the same API...
|
93 |
+
# TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
|
94 |
+
if isinstance(optimizers, FP16_Optimizer_for_fused):
|
95 |
+
loss_scale = optimizers.cur_scale
|
96 |
+
else:
|
97 |
+
loss_scaler = _amp_state.loss_scalers[loss_id]
|
98 |
+
loss_scale = loss_scaler.loss_scale()
|
99 |
+
|
100 |
+
if ((not _amp_state.opt_properties.master_weights)
|
101 |
+
and (not loss_scaler.dynamic)
|
102 |
+
and loss_scale == 1.0):
|
103 |
+
yield loss.float()
|
104 |
+
# Needing to drop the cache here as well is an ugly gotcha.
|
105 |
+
# But for now I think it's necessary to short-circuit.
|
106 |
+
# Probably ok to skip this if not delay_unscale
|
107 |
+
if _amp_state.opt_properties.patch_torch_functions:
|
108 |
+
_amp_state.handle._clear_cache()
|
109 |
+
return
|
110 |
+
|
111 |
+
if not delay_unscale:
|
112 |
+
if isinstance(optimizers, list):
|
113 |
+
for optimizer in optimizers:
|
114 |
+
if not optimizer._amp_stash.params_have_scaled_gradients:
|
115 |
+
optimizer._prepare_amp_backward()
|
116 |
+
|
117 |
+
yield (loss.float())*loss_scale
|
118 |
+
|
119 |
+
if delay_unscale:
|
120 |
+
for optimizer in optimizers:
|
121 |
+
optimizer._amp_stash.params_have_scaled_gradients = True
|
122 |
+
else:
|
123 |
+
# FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
|
124 |
+
if not isinstance(optimizers, FP16_Optimizer_for_fused):
|
125 |
+
loss_scaler.clear_overflow_state()
|
126 |
+
for optimizer in optimizers:
|
127 |
+
optimizer._post_amp_backward(loss_scaler)
|
128 |
+
optimizer._amp_stash.params_have_scaled_gradients = False
|
129 |
+
# For future fused optimizers that enable sync-free dynamic loss scaling,
|
130 |
+
# should_skip will always be False.
|
131 |
+
should_skip = False if delay_overflow_check else loss_scaler.update_scale()
|
132 |
+
if should_skip:
|
133 |
+
for optimizer in optimizers:
|
134 |
+
if not optimizer._amp_stash.already_patched:
|
135 |
+
# Close on loss_scaler and loss_id as well, to be safe. Probably not
|
136 |
+
# necessary because amp.scale_loss is already creating a temporary scope.
|
137 |
+
def patch_step(opt, loss_scaler, loss_id):
|
138 |
+
opt_step = opt.step
|
139 |
+
def skip_step(closure=None):
|
140 |
+
if closure is not None:
|
141 |
+
raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
|
142 |
+
maybe_print(("Gradient overflow. Skipping step, loss scaler " +
|
143 |
+
"{} reducing loss scale to {}").format(loss_id,
|
144 |
+
loss_scaler.loss_scale()))
|
145 |
+
if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
|
146 |
+
# Clear the master grads that wouldn't be zeroed by model.zero_grad()
|
147 |
+
for param in opt._amp_stash.all_fp32_from_fp16_params:
|
148 |
+
param.grad = None
|
149 |
+
opt.step = opt_step
|
150 |
+
opt._amp_stash.already_patched = False
|
151 |
+
return skip_step
|
152 |
+
optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
|
153 |
+
optimizer._amp_stash.already_patched = True
|
154 |
+
|
155 |
+
# Probably ok to skip this if not delay_unscale
|
156 |
+
if _amp_state.opt_properties.patch_torch_functions:
|
157 |
+
_amp_state.handle._clear_cache()
|
158 |
+
|
159 |
+
|
160 |
+
# Free function version of AmpHandle.disable_casts, another step on the
|
161 |
+
# path to removing the concept of "AmpHandle"
|
162 |
+
@contextlib.contextmanager
|
163 |
+
def disable_casts():
|
164 |
+
_amp_state.handle._is_active = False
|
165 |
+
yield
|
166 |
+
_amp_state.handle._is_active = True
|
167 |
+
|
168 |
+
|
169 |
+
class AmpHandle(object):
|
170 |
+
def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
|
171 |
+
self._enable_caching = enable_caching
|
172 |
+
self._verbose = verbose
|
173 |
+
self._cache = dict()
|
174 |
+
self._default_scaler = LossScaler(loss_scale)
|
175 |
+
self._is_active = True
|
176 |
+
self._all_wrappers = []
|
177 |
+
|
178 |
+
def is_active(self):
|
179 |
+
return self._is_active
|
180 |
+
|
181 |
+
@contextlib.contextmanager
|
182 |
+
def _disable_casts(self):
|
183 |
+
self._is_active = False
|
184 |
+
yield
|
185 |
+
self._is_active = True
|
186 |
+
|
187 |
+
def wrap_optimizer(self, optimizer, num_loss=1):
|
188 |
+
self._default_scaler = None
|
189 |
+
return OptimWrapper(optimizer, self, num_loss)
|
190 |
+
|
191 |
+
@contextlib.contextmanager
|
192 |
+
def scale_loss(self, loss, optimizer):
|
193 |
+
raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
|
194 |
+
"documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
|
195 |
+
"https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
|
196 |
+
|
197 |
+
if not self.is_active():
|
198 |
+
yield loss
|
199 |
+
return
|
200 |
+
|
201 |
+
if self._default_scaler is None:
|
202 |
+
raise RuntimeError(
|
203 |
+
'After calling `handle.wrap_optimizer()`, you must explicitly ' +
|
204 |
+
'use `optimizer.scale_loss(loss)`.')
|
205 |
+
|
206 |
+
# TODO: this code block is duplicated here and `opt.py`. Unify.
|
207 |
+
loss_scale = self._default_scaler.loss_scale()
|
208 |
+
yield loss * loss_scale
|
209 |
+
|
210 |
+
self._default_scaler.clear_overflow_state()
|
211 |
+
self._default_scaler.unscale(
|
212 |
+
master_params(optimizer),
|
213 |
+
master_params(optimizer),
|
214 |
+
loss_scale)
|
215 |
+
should_skip = self._default_scaler.update_scale()
|
216 |
+
if should_skip:
|
217 |
+
optimizer_step = optimizer.step
|
218 |
+
def skip_step():
|
219 |
+
maybe_print('Gradient overflow, skipping update')
|
220 |
+
optimizer.step = optimizer_step
|
221 |
+
optimizer.step = skip_step
|
222 |
+
|
223 |
+
self._clear_cache()
|
224 |
+
|
225 |
+
def _clear_cache(self):
|
226 |
+
self._cache.clear()
|
227 |
+
|
228 |
+
# Experimental support for saving / restoring uncasted versions of functions
|
229 |
+
def _save_func(self, mod, fn, func):
|
230 |
+
self._all_wrappers.append((mod, fn, func))
|
231 |
+
|
232 |
+
def _deactivate(self):
|
233 |
+
for mod, fn, func in self._all_wrappers:
|
234 |
+
utils.set_func(mod, fn, func)
|
235 |
+
self._all_wrappers = []
|
236 |
+
|
237 |
+
@property
|
238 |
+
def has_cache(self):
|
239 |
+
return self._enable_caching
|
240 |
+
|
241 |
+
@property
|
242 |
+
def cache(self):
|
243 |
+
return self._cache
|
244 |
+
|
245 |
+
def remove_cache(self, param):
|
246 |
+
if self.has_cache and param in self.cache:
|
247 |
+
del self.cache[param]
|
248 |
+
|
249 |
+
@property
|
250 |
+
def verbose(self):
|
251 |
+
return self._verbose
|
252 |
+
|
253 |
+
class NoOpHandle(object):
|
254 |
+
def is_active(self):
|
255 |
+
return False
|
256 |
+
|
257 |
+
@contextlib.contextmanager
|
258 |
+
def _disable_casts(self):
|
259 |
+
yield
|
260 |
+
|
261 |
+
def wrap_optimizer(self, optimizer, num_loss=1):
|
262 |
+
return OptimWrapper(optimizer, self, num_loss)
|
263 |
+
|
264 |
+
@contextlib.contextmanager
|
265 |
+
def scale_loss(self, loss, optimizer):
|
266 |
+
yield loss
|
267 |
+
|
268 |
+
@property
|
269 |
+
def has_cache(self):
|
270 |
+
return False
|
271 |
+
|
272 |
+
@property
|
273 |
+
def verbose(self):
|
274 |
+
return False
|
275 |
+
|
276 |
+
def _clear_cache(self):
|
277 |
+
pass
|
278 |
+
|
279 |
+
def _deactivate(self):
|
280 |
+
pass
|
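As a hedged sketch of the `num_losses`/`loss_id` combination referenced in the docstrings above (the model, optimizers, and losses here are assumed placeholders):

model, [optim0, optim1] = amp.initialize(model, [optim0, optim1],
                                         opt_level="O2", num_losses=2)

# Naming each backward pass lets Amp keep a separate dynamic loss scale per loss.
with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss:
    scaled_loss.backward()
optim0.step()

with amp.scale_loss(loss1, optim1, loss_id=1) as scaled_loss:
    scaled_loss.backward()
optim1.step()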
apex/apex/amp/lists/__init__.py
ADDED
File without changes
|
apex/apex/amp/lists/functional_overrides.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
|
2 |
+
# TODO: think about the following two. They do weird things.
|
3 |
+
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
|
4 |
+
# - torch.nn.utils.weight_norm
|
5 |
+
|
6 |
+
# Notes:
|
7 |
+
# F.instance_norm uses batch_norm internally. Which correctly handles
|
8 |
+
# fp16 in/out with fp32 weights. So we shouldn't do anything for
|
9 |
+
# either of these.
|
10 |
+
# F.normalize calls `input.norm()` internally, so it's redundant, but
|
11 |
+
# kept here in case impl. changes.
|
12 |
+
# F.cosine_similarity is same: calls `x.norm()` internally.
|
13 |
+
|
14 |
+
import torch.nn.functional
|
15 |
+
|
16 |
+
MODULE = torch.nn.functional
|
17 |
+
|
18 |
+
FP16_FUNCS = [
|
19 |
+
'conv1d',
|
20 |
+
'conv2d',
|
21 |
+
'conv3d',
|
22 |
+
'conv_transpose1d',
|
23 |
+
'conv_transpose2d',
|
24 |
+
'conv_transpose3d',
|
25 |
+
'conv_tbc', # Undocumented / maybe new?
|
26 |
+
'linear',
|
27 |
+
]
|
28 |
+
|
29 |
+
FP32_FUNCS = [
|
30 |
+
|
31 |
+
# Interpolation/Upsampling
|
32 |
+
'interpolate',
|
33 |
+
|
34 |
+
# Pointwise
|
35 |
+
'softplus',
|
36 |
+
'softmin',
|
37 |
+
'log_softmax',
|
38 |
+
'softmax',
|
39 |
+
|
40 |
+
# Normalization
|
41 |
+
'layer_norm',
|
42 |
+
'group_norm',
|
43 |
+
'local_response_norm',
|
44 |
+
'normalize',
|
45 |
+
'cosine_similarity',
|
46 |
+
|
47 |
+
# Loss functions
|
48 |
+
# TODO: which of these can be fp16?
|
49 |
+
'poisson_nll_loss',
|
50 |
+
'cosine_embedding_loss',
|
51 |
+
'cross_entropy',
|
52 |
+
'hinge_embedding_loss',
|
53 |
+
'kl_div',
|
54 |
+
'l1_loss',
|
55 |
+
'mse_loss',
|
56 |
+
'margin_ranking_loss',
|
57 |
+
'multilabel_margin_loss',
|
58 |
+
'multilabel_soft_margin_loss',
|
59 |
+
'multi_margin_loss',
|
60 |
+
'nll_loss',
|
61 |
+
'binary_cross_entropy_with_logits',
|
62 |
+
'smooth_l1_loss',
|
63 |
+
'soft_margin_loss',
|
64 |
+
'triplet_margin_loss'
|
65 |
+
]
|
66 |
+
|
67 |
+
BANNED_FUNCS = [
|
68 |
+
('binary_cross_entropy',
|
69 |
+
("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
|
70 |
+
"It requires that the output of the previous function be already a FloatTensor. \n\n"
|
71 |
+
"Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
|
72 |
+
" torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
|
73 |
+
"that is compatible with amp.\nAnother option is to add\n"
|
74 |
+
" amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
|
75 |
+
"If you _really_ know what you are doing, you can disable this warning by passing "
|
76 |
+
"allow_banned=True to `amp.init()`."))
|
77 |
+
]
|
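To make the `binary_cross_entropy` message above concrete, a small sketch of the suggested rewrite (`logits` and `targets` are placeholder tensors):

import torch
import torch.nn.functional as F

# Pattern the warning is about: sigmoid followed by BCE on the probabilities.
probs = torch.sigmoid(logits)
loss = F.binary_cross_entropy(probs, targets)

# Amp-friendly alternative suggested above: fuse the sigmoid into the loss.
loss = F.binary_cross_entropy_with_logits(logits, targets)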
apex/apex/amp/lists/tensor_overrides.py
ADDED
@@ -0,0 +1,63 @@
1 |
+
from .. import compat
|
2 |
+
from . import torch_overrides
|
3 |
+
|
4 |
+
import importlib
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
if compat.variable_is_tensor() and not compat.tensor_is_variable():
|
9 |
+
MODULE = torch.Tensor
|
10 |
+
else:
|
11 |
+
MODULE = torch.autograd.Variable
|
12 |
+
|
13 |
+
|
14 |
+
FP16_FUNCS = [
|
15 |
+
'__matmul__',
|
16 |
+
]
|
17 |
+
|
18 |
+
FP32_FUNCS = [
|
19 |
+
'__ipow__',
|
20 |
+
'__pow__',
|
21 |
+
'__rpow__',
|
22 |
+
|
23 |
+
# Cast to fp32 before transfer to CPU
|
24 |
+
'cpu',
|
25 |
+
]
|
26 |
+
|
27 |
+
CASTS = [
|
28 |
+
'__add__',
|
29 |
+
'__div__',
|
30 |
+
'__eq__',
|
31 |
+
'__ge__',
|
32 |
+
'__gt__',
|
33 |
+
'__iadd__',
|
34 |
+
'__idiv__',
|
35 |
+
'__imul__',
|
36 |
+
'__isub__',
|
37 |
+
'__itruediv__',
|
38 |
+
'__le__',
|
39 |
+
'__lt__',
|
40 |
+
'__mul__',
|
41 |
+
'__ne__',
|
42 |
+
'__radd__',
|
43 |
+
'__rdiv__',
|
44 |
+
'__rmul__',
|
45 |
+
'__rsub__',
|
46 |
+
'__rtruediv__',
|
47 |
+
'__sub__',
|
48 |
+
'__truediv__',
|
49 |
+
]
|
50 |
+
|
51 |
+
# None of these, but here to make code cleaner.
|
52 |
+
SEQUENCE_CASTS = []
|
53 |
+
|
54 |
+
# We need to grab all the methods from torch_overrides and add them to
|
55 |
+
# the Tensor lists as well, as almost all methods are duplicated
|
56 |
+
# between `torch` and `torch.Tensor` (and check with `hasattr`,
|
57 |
+
# because a few random ones aren't defined on Tensor)
|
58 |
+
_self_mod = importlib.import_module(__name__)
|
59 |
+
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
|
60 |
+
lst = getattr(_self_mod, attrname)
|
61 |
+
for fn in getattr(torch_overrides, attrname):
|
62 |
+
if hasattr(MODULE, fn):
|
63 |
+
lst.append(fn)
|
apex/apex/amp/lists/torch_overrides.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
from .. import utils
|
4 |
+
|
5 |
+
MODULE = torch
|
6 |
+
|
7 |
+
FP16_FUNCS = [
|
8 |
+
# Low level functions wrapped by torch.nn layers.
|
9 |
+
# The wrapper layers contain the weights which are then passed in as a parameter
|
10 |
+
# to these functions.
|
11 |
+
'conv1d',
|
12 |
+
'conv2d',
|
13 |
+
'conv3d',
|
14 |
+
'conv_transpose1d',
|
15 |
+
'conv_transpose2d',
|
16 |
+
'conv_transpose3d',
|
17 |
+
'conv_tbc',
|
18 |
+
'prelu',
|
19 |
+
|
20 |
+
# BLAS
|
21 |
+
'addmm',
|
22 |
+
'addmv',
|
23 |
+
'addr',
|
24 |
+
'matmul',
|
25 |
+
'mm',
|
26 |
+
'mv',
|
27 |
+
]
|
28 |
+
|
29 |
+
FP32_FUNCS = [
|
30 |
+
# Pointwise
|
31 |
+
'acos',
|
32 |
+
'asin',
|
33 |
+
'cosh',
|
34 |
+
'erfinv',
|
35 |
+
'exp',
|
36 |
+
'expm1',
|
37 |
+
'log',
|
38 |
+
'log10',
|
39 |
+
'log2',
|
40 |
+
'reciprocal',
|
41 |
+
'rsqrt',
|
42 |
+
'sinh',
|
43 |
+
'tan',
|
44 |
+
|
45 |
+
# Other math
|
46 |
+
'pow',
|
47 |
+
|
48 |
+
# Reduction
|
49 |
+
'cumprod',
|
50 |
+
'cumsum',
|
51 |
+
'dist',
|
52 |
+
'mean',
|
53 |
+
'norm',
|
54 |
+
'prod',
|
55 |
+
'std',
|
56 |
+
'sum',
|
57 |
+
'var',
|
58 |
+
|
59 |
+
# Misc
|
60 |
+
'renorm'
|
61 |
+
]
|
62 |
+
|
63 |
+
# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
|
64 |
+
# check the CUDA version -- if at least 9.1, then put the bmm
|
65 |
+
# functions on the fp16 list. Otherwise, put them on the fp32 list.
|
66 |
+
_bmms = ['addbmm',
|
67 |
+
'baddbmm',
|
68 |
+
'bmm']
|
69 |
+
if utils.get_cuda_version() >= (9, 1, 0):
|
70 |
+
FP16_FUNCS.extend(_bmms)
|
71 |
+
else:
|
72 |
+
FP32_FUNCS.extend(_bmms)
|
73 |
+
|
74 |
+
# Multi-tensor fns that may need type promotion
|
75 |
+
CASTS = [
|
76 |
+
# Multi-tensor math
|
77 |
+
'addcdiv',
|
78 |
+
'addcmul',
|
79 |
+
'atan2',
|
80 |
+
'cross',
|
81 |
+
'bilinear',
|
82 |
+
|
83 |
+
# Element-wise _or_ tensor-wise math
|
84 |
+
'add',
|
85 |
+
'div',
|
86 |
+
'mul',
|
87 |
+
|
88 |
+
# Comparison
|
89 |
+
'eq',
|
90 |
+
'equal',
|
91 |
+
'ge',
|
92 |
+
'gt',
|
93 |
+
'le',
|
94 |
+
'lt',
|
95 |
+
'ne'
|
96 |
+
]
|
97 |
+
|
98 |
+
# Functions that take sequence arguments. We need to inspect the whole
|
99 |
+
# sequence and cast to the widest type.
|
100 |
+
SEQUENCE_CASTS = [
|
101 |
+
'cat',
|
102 |
+
'stack'
|
103 |
+
]
|
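These lists are consumed by Amp's wrapping machinery (wrap.py, not included in this view). Purely as an illustrative sketch, not the actual implementation, a name from FP16_FUNCS could be patched roughly like this; `_cast_wrapper` and the example lambda are hypothetical:

import torch

def _cast_wrapper(mod, fn_name, cast_fn):
    # Replace mod.fn_name with a version that routes floating-point tensor
    # arguments through cast_fn before dispatching to the original function.
    orig = getattr(mod, fn_name)
    def wrapper(*args, **kwargs):
        args = [cast_fn(a) if torch.is_tensor(a) and a.is_floating_point() else a
                for a in args]
        kwargs = {k: cast_fn(v) if torch.is_tensor(v) and v.is_floating_point() else v
                  for k, v in kwargs.items()}
        return orig(*args, **kwargs)
    setattr(mod, fn_name, wrapper)

# e.g. cast torch.mm inputs to half precision on the GPU:
# _cast_wrapper(torch, 'mm', lambda t: t.half() if t.is_cuda else t)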
apex/apex/amp/opt.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
import contextlib
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
from .scaler import LossScaler, master_params
|
5 |
+
from ._amp_state import maybe_print
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class OptimWrapper(object):
|
10 |
+
def __init__(self, optimizer, amp_handle, num_loss):
|
11 |
+
self._optimizer = optimizer
|
12 |
+
self._amp_handle = amp_handle
|
13 |
+
self._num_loss = num_loss
|
14 |
+
self._loss_idx = 0
|
15 |
+
self._skip_next = [False] * num_loss
|
16 |
+
self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
|
17 |
+
|
18 |
+
@contextlib.contextmanager
|
19 |
+
def scale_loss(self, loss):
|
20 |
+
if not self._amp_handle.is_active():
|
21 |
+
yield loss
|
22 |
+
return
|
23 |
+
|
24 |
+
# When there are multiple losses per-optimizer, we need
|
25 |
+
# to save out current grad accumulation, since we won't be
|
26 |
+
# able to unscale this particular loss once the grads are
|
27 |
+
# all mixed together.
|
28 |
+
cached_grads = []
|
29 |
+
if self._loss_idx > 0:
|
30 |
+
for p in master_params(self._optimizer):
|
31 |
+
if p.grad is not None:
|
32 |
+
cached_grads.append(p.grad.data.detach().clone())
|
33 |
+
else:
|
34 |
+
cached_grads.append(None)
|
35 |
+
self._optimizer.zero_grad()
|
36 |
+
|
37 |
+
loss_scale = self._cur_loss_scaler().loss_scale()
|
38 |
+
yield loss * loss_scale
|
39 |
+
|
40 |
+
self._cur_loss_scaler().clear_overflow_state()
|
41 |
+
self._cur_loss_scaler().unscale(
|
42 |
+
master_params(self._optimizer),
|
43 |
+
master_params(self._optimizer),
|
44 |
+
loss_scale)
|
45 |
+
self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
|
46 |
+
self._loss_idx += 1
|
47 |
+
|
48 |
+
if len(cached_grads) > 0:
|
49 |
+
for p, cached_grad in zip(master_params(self._optimizer),
|
50 |
+
cached_grads):
|
51 |
+
if cached_grad is not None:
|
52 |
+
p.grad.data.add_(cached_grad)
|
53 |
+
cached_grads = []
|
54 |
+
|
55 |
+
def _cur_loss_scaler(self):
|
56 |
+
assert 0 <= self._loss_idx < self._num_loss
|
57 |
+
return self._loss_scaler[self._loss_idx]
|
58 |
+
|
59 |
+
def step(self, closure=None):
|
60 |
+
if not self._amp_handle.is_active():
|
61 |
+
return self._optimizer.step(closure=closure)
|
62 |
+
|
63 |
+
self._loss_idx = 0
|
64 |
+
|
65 |
+
for group in self._optimizer.param_groups:
|
66 |
+
for p in group['params']:
|
67 |
+
self._amp_handle.remove_cache(p)
|
68 |
+
|
69 |
+
if closure is not None:
|
70 |
+
raise NotImplementedError(
|
71 |
+
'The `closure` argument is unsupported by the amp ' +
|
72 |
+
'optimizer wrapper.')
|
73 |
+
if any(self._skip_next):
|
74 |
+
maybe_print('Gradient overflow, skipping update')
|
75 |
+
self._skip_next = [False] * self._num_loss
|
76 |
+
else:
|
77 |
+
return self._optimizer.step(closure=closure)
|
78 |
+
|
79 |
+
# Forward any attribute lookups
|
80 |
+
def __getattr__(self, attr):
|
81 |
+
return getattr(self._optimizer, attr)
|
82 |
+
|
83 |
+
# Forward all torch.optim.Optimizer methods
|
84 |
+
def __getstate__(self):
|
85 |
+
return self._optimizer.__getstate__()
|
86 |
+
|
87 |
+
def __setstate__(self):
|
88 |
+
return self._optimizer.__setstate__()
|
89 |
+
|
90 |
+
def __repr__(self):
|
91 |
+
return self._optimizer.__repr__()
|
92 |
+
|
93 |
+
def state_dict(self):
|
94 |
+
return self._optimizer.state_dict()
|
95 |
+
|
96 |
+
def load_state_dict(self, state_dict):
|
97 |
+
return self._optimizer.load_state_dict(state_dict)
|
98 |
+
|
99 |
+
def zero_grad(self):
|
100 |
+
return self._optimizer.zero_grad()
|
101 |
+
|
102 |
+
def add_param_group(self, param_group):
|
103 |
+
return self._optimizer.add_param_group(param_group)
|
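`OptimWrapper` backs the older handle-based flow referenced by the error message in handle.py. A rough sketch of that deprecated usage, assuming `handle` comes from the legacy `amp.init()` entry point:

optimizer = handle.wrap_optimizer(optimizer)      # returns an OptimWrapper
with optimizer.scale_loss(loss) as scaled_loss:
    scaled_loss.backward()
optimizer.step()   # skipped, with a printed warning, if any loss overflowed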
apex/apex/amp/rnn_compat.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
from . import utils, wrap
|
2 |
+
|
3 |
+
import torch
|
4 |
+
_VF = torch._C._VariableFunctions
|
5 |
+
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
|
6 |
+
|
7 |
+
def _gen_VF_wrapper(name):
|
8 |
+
def wrapper(*args, **kwargs):
|
9 |
+
return getattr(_VF, name)(*args, **kwargs)
|
10 |
+
return wrapper
|
11 |
+
|
12 |
+
# Some python magic to generate an object that has the rnn cell functions
|
13 |
+
# defined on it, all of which call into corresponding _VF version.
|
14 |
+
# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
|
15 |
+
# imported at module scope within torch.nn.modules.rnn). This should
|
16 |
+
# not affect third-party importers of _VF.py.
|
17 |
+
class VariableFunctionsShim(object):
|
18 |
+
def __init__(self):
|
19 |
+
for name in RNN_NAMES:
|
20 |
+
for suffix in ['', '_cell']:
|
21 |
+
fn_name = name + suffix
|
22 |
+
setattr(self, fn_name, _gen_VF_wrapper(fn_name))
|
23 |
+
|
24 |
+
def has_old_rnns():
|
25 |
+
try:
|
26 |
+
torch.nn.backends.thnn.backend.LSTMCell
|
27 |
+
return True
|
28 |
+
except:
|
29 |
+
return False
|
30 |
+
|
31 |
+
def whitelist_rnn_cells(handle, verbose):
|
32 |
+
# Different module + function names in old/new RNN cases
|
33 |
+
if has_old_rnns():
|
34 |
+
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
|
35 |
+
mod = torch.nn.backends.thnn.backend
|
36 |
+
else:
|
37 |
+
fn_names = [x + '_cell' for x in RNN_NAMES]
|
38 |
+
mod = torch.nn.modules.rnn._VF
|
39 |
+
assert isinstance(mod, VariableFunctionsShim)
|
40 |
+
|
41 |
+
# Insert casts on cell functions
|
42 |
+
for fn in fn_names:
|
43 |
+
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
|
44 |
+
try_caching=True, verbose=verbose)
|
45 |
+
|
46 |
+
if has_old_rnns():
|
47 |
+
# Special handling of `backward` for fused gru / lstm:
|
48 |
+
# The `backward` method calls Tensor.sum() (blacklist) internally,
|
49 |
+
# and then the resulting grad_input has the wrong type.
|
50 |
+
# TODO: where else is this a problem?
|
51 |
+
for rnn_type in ['GRUFused', 'LSTMFused']:
|
52 |
+
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
|
53 |
+
wrap.disable_casts(mod, 'backward', handle)
|
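A hedged sketch of how the shim above is intended to be installed, following the comments in this file (in practice Amp's initialization drives this; `handle` is assumed to exist):

import torch.nn.modules.rnn as rnn_module

# Swap the module-level _VF reference for the shim, then patch the cell functions.
rnn_module._VF = VariableFunctionsShim()
whitelist_rnn_cells(handle, verbose=False)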
apex/apex/amp/scaler.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
import torch
|
2 |
+
from ..multi_tensor_apply import multi_tensor_applier
|
3 |
+
from ._amp_state import _amp_state, master_params, maybe_print
|
4 |
+
from itertools import product
|
5 |
+
|
6 |
+
def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
|
7 |
+
# Exception handling for 18.04 compatibility
|
8 |
+
if check_overflow:
|
9 |
+
cpu_sum = float(model_grad.float().sum())
|
10 |
+
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
|
11 |
+
return True
|
12 |
+
|
13 |
+
if master_grad is not model_grad: # copy_ probably internally short-circuits this
|
14 |
+
master_grad.copy_(model_grad)
|
15 |
+
if scale != 1.0:
|
16 |
+
master_grad.mul_(scale)
|
17 |
+
return False
|
18 |
+
|
19 |
+
def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
|
20 |
+
# Exception handling for 18.04 compatibility
|
21 |
+
if check_overflow:
|
22 |
+
cpu_sum = float(model_grad.float().sum())
|
23 |
+
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
|
24 |
+
return True
|
25 |
+
|
26 |
+
# if master_grad is not model_grad: # copy_ probably internally short-circuits this
|
27 |
+
# master_grad.copy_(model_grad)
|
28 |
+
assert stashed_grad.dtype == master_grad.dtype
|
29 |
+
converted_model_grad = model_grad.to(master_grad.dtype)
|
30 |
+
stashed_grad.add_(scale, converted_model_grad)
|
31 |
+
master_grad.data = stashed_grad.data
|
32 |
+
return False
|
33 |
+
|
34 |
+
class LossScaler(object):
|
35 |
+
warned_no_fused_kernel = False
|
36 |
+
warned_unscaling_non_fp32_grad = False
|
37 |
+
has_fused_kernel = False
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
loss_scale,
|
41 |
+
init_scale=2.**16,
|
42 |
+
scale_factor=2.,
|
43 |
+
scale_window=2000,
|
44 |
+
min_loss_scale=None,
|
45 |
+
max_loss_scale=2.**24):
|
46 |
+
if loss_scale == "dynamic":
|
47 |
+
self.dynamic = True
|
48 |
+
self._loss_scale = init_scale
|
49 |
+
else:
|
50 |
+
self.dynamic = False
|
51 |
+
self._loss_scale = loss_scale
|
52 |
+
self._max_loss_scale = max_loss_scale
|
53 |
+
self._min_loss_scale = min_loss_scale
|
54 |
+
self._scale_seq_len = scale_window
|
55 |
+
self._unskipped = 0
|
56 |
+
self._has_overflow = False
|
57 |
+
self._overflow_buf = torch.cuda.IntTensor([0])
|
58 |
+
if multi_tensor_applier.available:
|
59 |
+
import amp_C
|
60 |
+
LossScaler.has_fused_kernel = multi_tensor_applier.available
|
61 |
+
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
|
62 |
+
LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
|
63 |
+
else:
|
64 |
+
if not LossScaler.warned_no_fused_kernel:
|
65 |
+
maybe_print(
|
66 |
+
"Warning: multi_tensor_applier fused unscale kernel is unavailable, "
|
67 |
+
"possibly because apex was installed without --cuda_ext --cpp_ext. "
|
68 |
+
"Using Python fallback. Original ImportError was: " +
|
69 |
+
repr(multi_tensor_applier.import_err),
|
70 |
+
True)
|
71 |
+
LossScaler.has_fused_kernel = False
|
72 |
+
LossScaler.warned_no_fused_kernel = True
|
73 |
+
|
74 |
+
def loss_scale(self):
|
75 |
+
return self._loss_scale
|
76 |
+
|
77 |
+
def unscale_python(self, model_grads, master_grads, scale):
|
78 |
+
for model, master in zip(model_grads, master_grads):
|
79 |
+
if model is not None:
|
80 |
+
if not LossScaler.warned_unscaling_non_fp32_grad:
|
81 |
+
if master.dtype != torch.float32:
|
82 |
+
maybe_print(
|
83 |
+
"Attempting to unscale a grad with type {} ".format(master.type()) +
|
84 |
+
"Unscaling non-fp32 grads may indicate an error. "
|
85 |
+
"When using Amp, you don't need to call .half() on your model.")
|
86 |
+
LossScaler.warned_unscaling_non_fp32_grad = True
|
87 |
+
self._has_overflow = scale_check_overflow_python(model,
|
88 |
+
master,
|
89 |
+
1./scale,
|
90 |
+
self.dynamic)
|
91 |
+
if self._has_overflow and self.dynamic:
|
92 |
+
break
|
93 |
+
|
94 |
+
# unused_scale keeps some of the old API alive for hopefully a short time.
|
95 |
+
def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
|
96 |
+
if self._has_overflow:
|
97 |
+
return
|
98 |
+
|
99 |
+
scale = self._loss_scale
|
100 |
+
|
101 |
+
if scale == 1.0 and models_are_masters and not self.dynamic:
|
102 |
+
return
|
103 |
+
|
104 |
+
if LossScaler.has_fused_kernel:
|
105 |
+
# if (not LossScaler.warned_unscaling_non_fp32_grad
|
106 |
+
# and master_grads[0].dtype == torch.float16):
|
107 |
+
# print("Warning: unscaling grads that are not FP32. "
|
108 |
+
# "Unscaling non-fp32 grads may indicate an error. "
|
109 |
+
# "When using Amp, you don't need to call .half() on your model.")
|
110 |
+
# # Setting this to True unconditionally allows the possibility of an escape
|
111 |
+
# # if never-before-seen non-fp32 grads are created in some later iteration.
|
112 |
+
# LossScaler.warned_unscaling_non_fp32_grad = True
|
113 |
+
multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
|
114 |
+
self._overflow_buf,
|
115 |
+
[model_grads, master_grads],
|
116 |
+
1./scale)
|
117 |
+
else:
|
118 |
+
self.unscale_python(model_grads, master_grads, scale)
|
119 |
+
|
120 |
+
# Defer to update_scale
|
121 |
+
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
122 |
+
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
123 |
+
# self._has_overflow = self._overflow_buf.item()
|
124 |
+
|
125 |
+
def unscale_with_stashed_python(self,
|
126 |
+
model_grads,
|
127 |
+
stashed_master_grads,
|
128 |
+
master_grads,
|
129 |
+
scale):
|
130 |
+
for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
|
131 |
+
if model is None and stashed is None:
|
132 |
+
continue
|
133 |
+
else:
|
134 |
+
if not LossScaler.warned_unscaling_non_fp32_grad:
|
135 |
+
if master.dtype != torch.float32:
|
136 |
+
maybe_print(
|
137 |
+
"Attempting to unscale a grad with type {} ".format(master.type()) +
|
138 |
+
"Unscaling non-fp32 grads may indicate an error. "
|
139 |
+
"When using Amp, you don't need to call .half() on your model.")
|
140 |
+
LossScaler.warned_unscaling_non_fp32_grad = True
|
141 |
+
self._has_overflow = axpby_check_overflow_python(model,
|
142 |
+
stashed,
|
143 |
+
master,
|
144 |
+
1./scale,
|
145 |
+
self.dynamic)
|
146 |
+
if self._has_overflow and self.dynamic:
|
147 |
+
break
|
148 |
+
|
149 |
+
def unscale_with_stashed(self,
|
150 |
+
model_grads,
|
151 |
+
stashed_master_grads,
|
152 |
+
master_grads):
|
153 |
+
if self._has_overflow:
|
154 |
+
return
|
155 |
+
|
156 |
+
scale = self._loss_scale
|
157 |
+
|
158 |
+
if LossScaler.has_fused_kernel:
|
159 |
+
if (not LossScaler.warned_unscaling_non_fp32_grad
|
160 |
+
and master_grads[0].dtype == torch.float16):
|
161 |
+
print("Warning: unscaling grads that are not FP32. "
|
162 |
+
"Unscaling non-fp32 grads may indicate an error. "
|
163 |
+
"When using Amp, you don't need to call .half() on your model.")
|
164 |
+
# Setting this to True unconditionally allows the possibility of an escape
|
165 |
+
# if never-before-seen non-fp32 grads are created in some later iteration.
|
166 |
+
LossScaler.warned_unscaling_non_fp32_grad = True
|
167 |
+
multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
|
168 |
+
self._overflow_buf,
|
169 |
+
[model_grads, stashed_master_grads, master_grads],
|
170 |
+
1./scale,
|
171 |
+
1.0,
|
172 |
+
0) # check only arg 0, aka the incoming model grads, for infs
|
173 |
+
else:
|
174 |
+
self.unscale_with_stashed_python(model_grads,
|
175 |
+
stashed_master_grads,
|
176 |
+
master_grads,
|
177 |
+
scale)
|
178 |
+
|
179 |
+
# Defer to update_scale
|
180 |
+
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
181 |
+
# if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
182 |
+
# self._has_overflow = self._overflow_buf.item()
|
183 |
+
|
184 |
+
def clear_overflow_state(self):
|
185 |
+
self._has_overflow = False
|
186 |
+
if self.has_fused_kernel:
|
187 |
+
self._overflow_buf.zero_()
|
188 |
+
|
189 |
+
# Separate so unscale() can be called more than once before updating.
|
190 |
+
def update_scale(self):
|
191 |
+
# If the fused kernel is available, we only need one D2H memcopy and sync.
|
192 |
+
if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
|
193 |
+
self._has_overflow = self._overflow_buf.item()
|
194 |
+
|
195 |
+
if self._has_overflow and self.dynamic:
|
196 |
+
should_skip = True
|
197 |
+
if(self._min_loss_scale):
|
198 |
+
self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
|
199 |
+
else:
|
200 |
+
self._loss_scale = self._loss_scale/2.
|
201 |
+
self._unskipped = 0
|
202 |
+
else:
|
203 |
+
should_skip = False
|
204 |
+
self._unskipped += 1
|
205 |
+
|
206 |
+
if self._unskipped == self._scale_seq_len and self.dynamic:
|
207 |
+
self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
|
208 |
+
self._unskipped = 0
|
209 |
+
|
210 |
+
return should_skip
|
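Restating the dynamic loss-scale policy implemented by `update_scale` above as a standalone helper, purely for illustration (the function name and signature are hypothetical):

def next_loss_scale(loss_scale, had_overflow, unskipped,
                    scale_window=2000, min_scale=None, max_scale=2.**24):
    # Halve the scale on overflow, respecting the optional floor...
    if had_overflow:
        loss_scale = loss_scale / 2.
        if min_scale is not None:
            loss_scale = max(min_scale, loss_scale)
        unskipped = 0
    else:
        # ...and double it after `scale_window` consecutive clean steps,
        # respecting the ceiling.
        unskipped += 1
        if unskipped == scale_window:
            loss_scale = min(max_scale, loss_scale * 2.)
            unskipped = 0
    return loss_scale, unskipped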
apex/apex/amp/utils.py
ADDED
@@ -0,0 +1,213 @@
from . import compat

import functools
import itertools

import torch

def get_cuda_version():
    return tuple(int(x) for x in torch.version.cuda.split('.'))

def is_fp_tensor(x):
    if is_nested(x):
        # Fast-fail version of all(is_fp_tensor)
        for y in x:
            if not is_fp_tensor(y):
                return False
        return True
    return compat.is_tensor_like(x) and compat.is_floating_point(x)

def is_nested(x):
    return isinstance(x, tuple) or isinstance(x, list)

def should_cache(x):
    if is_nested(x):
        # Fast-fail version of all(should_cache)
        for y in x:
            if not should_cache(y):
                return False
        return True
    return isinstance(x, torch.nn.parameter.Parameter) and \
        type_string(x) == 'FloatTensor'

def collect_fp_tensor_types(args, kwargs):
    def collect_types(x, types):
        if is_nested(x):
            for y in x:
                collect_types(y, types)
        else:
            types.add(type_string(x))

    all_args = itertools.chain(args, kwargs.values())
    types = set()
    for x in all_args:
        if is_fp_tensor(x):
            collect_types(x, types)
    return types

def type_string(x):
    return x.type().split('.')[-1]

def maybe_half(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_half(y) for y in x])

    if not x.is_cuda or type_string(x) == 'HalfTensor':
        return x
    else:
        if verbose:
            print('Float->Half ({})'.format(name))
        return x.half()

def maybe_float(x, name='', verbose=False):
    if is_nested(x):
        return type(x)([maybe_float(y) for y in x])

    if not x.is_cuda or type_string(x) == 'FloatTensor':
        return x
    else:
        if verbose:
            print('Half->Float ({})'.format(name))
        return x.float()

# NB: returns casted `args`, mutates `kwargs` in-place
def casted_args(cast_fn, args, kwargs):
    new_args = []
    for x in args:
        if is_fp_tensor(x):
            new_args.append(cast_fn(x))
        else:
            new_args.append(x)
    for k in kwargs:
        val = kwargs[k]
        if is_fp_tensor(val):
            kwargs[k] = cast_fn(val)
    return new_args

def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad and cached_x.requires_grad:
            # Make sure x is actually cached_x's autograd parent.
            if cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent. This is likely an error.")
        # During eval, it's possible to end up caching casted weights with
        # requires_grad=False. On the next training iter, if cached_x is found
        # and reused from the cache, it will not actually have x as its parent.
        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
        # if x.requires_grad and cached_x.requires_grad do not match.
        #
        # During eval (i.e. running under with torch.no_grad()) the invalidation
        # check would cause the cached value to be dropped every time, because
        # cached_x would always be created with requires_grad=False, while x would
        # still have requires_grad=True. This would render the cache effectively
        # useless during eval. Therefore, if we are running under the no_grad()
        # context manager (torch.is_grad_enabled=False) we elide the invalidation
        # check, and use the cached value even though its requires_grad flag doesn't
        # match. During eval, we don't care that there's no autograd-graph
        # connection between x and cached_x.
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]
        else:
            return cached_x

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x

def verbosify(cast_fn, fn_name, verbose):
    if verbose:
        return functools.partial(cast_fn, name=fn_name, verbose=verbose)
    else:
        return cast_fn

def as_inplace(fns):
    for x in fns:
        yield x + '_'

def has_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return fn in mod.function_classes
    elif isinstance(mod, dict):
        return fn in mod
    else:
        return hasattr(mod, fn)

def get_func(mod, fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        return mod.function_classes[fn]
    elif isinstance(mod, dict):
        return mod[fn]
    else:
        return getattr(mod, fn)

def set_func(mod, fn, new_fn):
    if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
        mod.function_classes[fn] = new_fn
    elif isinstance(mod, dict):
        mod[fn] = new_fn
    else:
        setattr(mod, fn, new_fn)

def set_func_save(handle, mod, fn, new_fn):
    cur_fn = get_func(mod, fn)
    handle._save_func(mod, fn, cur_fn)
    set_func(mod, fn, new_fn)

# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
#   so the fp16 weights need to be derived from the input weights
#   to this forward call, not the flat buffer.
# - The ordering of weights in the flat buffer is...idiosyncratic.
# First problem is solved with combination of set_ (to set up
# correct storage) and copy_ (so the fp16 weight derives from the
# fp32 one in autograd).
# Second is solved by doing ptr arithmetic on the fp32 weights
# to derive the correct offset.
#
# TODO: maybe this should actually use
# `torch._cudnn_rnn_flatten_weight`? But then I need to call
# on first iter and cache the right offsets. Ugh.
def synthesize_flattened_rnn_weights(fp32_weights,
                                     fp16_flat_tensor,
                                     rnn_fn='',
                                     verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0][0].data_ptr()
    for layer_weights in fp32_weights:
        fp16_layer_weights = []
        for w_fp32 in layer_weights:
            w_fp16 = w_fp32.new().half()
            offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
            w_fp16.set_(fp16_flat_tensor.storage(),
                        offset,
                        w_fp32.shape)
            w_fp16.copy_(w_fp32)
            if verbose:
                print('Float->Half ({})'.format(rnn_fn))
            fp16_layer_weights.append(w_fp16)
        fp16_weights.append(fp16_layer_weights)
    return fp16_weights

# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
                                         fp16_flat_tensor,
                                         rnn_fn='',
                                         verbose=False):
    fp16_weights = []
    fp32_base_ptr = fp32_weights[0].data_ptr()
    for w_fp32 in fp32_weights:
        w_fp16 = w_fp32.new().half()
        offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
        w_fp16.set_(fp16_flat_tensor.storage(),
                    offset,
                    w_fp32.shape)
        w_fp16.copy_(w_fp32)
        if verbose:
            print('Float->Half ({})'.format(rnn_fn))
        fp16_weights.append(w_fp16)
    return fp16_weights
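The cast-with-cache pattern above (`cached_cast` reusing a previously casted tensor, `casted_args` applying a cast across positional and keyword arguments) is easiest to see with a toy cast function and a plain dict as the cache. The helpers below are illustrative stand-ins, not the apex implementations.

```python
# Toy illustration of the cast-with-cache idea behind cached_cast / casted_args.
# These helpers are illustrative stand-ins, not the apex implementations.
def expensive_cast(x):
    print('casting', x)
    return float(x)

def toy_cached_cast(cast_fn, x, cache):
    # Reuse a previously computed cast of the same value if we have one.
    if x in cache:
        return cache[x]
    casted = cast_fn(x)
    cache[x] = casted
    return casted

def toy_casted_args(cast_fn, args, cache):
    # Cast every argument, hitting the cache where possible.
    return [toy_cached_cast(cast_fn, a, cache) for a in args]

cache = {}
print(toy_casted_args(expensive_cast, [1, 2, 1], cache))  # 1 and 2 are cast once each
print(toy_casted_args(expensive_cast, [1, 3], cache))     # only 3 triggers a new cast
```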
apex/apex/amp/wrap.py
ADDED
@@ -0,0 +1,276 @@
from . import compat
from . import utils
from ._amp_state import _amp_state
from . import rnn_compat

import functools

import torch

def make_cast_wrapper(orig_fn, cast_fn, handle,
                      try_caching=False):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not handle.is_active():
            return orig_fn(*args, **kwargs)

        if try_caching and handle.has_cache:
            args = list(args)
            for i in range(len(args)):
                if utils.should_cache(args[i]):
                    args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
            for k in kwargs:
                if utils.should_cache(kwargs[k]):
                    kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
        new_args = utils.casted_args(cast_fn,
                                     args,
                                     kwargs)
        return orig_fn(*new_args, **kwargs)
    return wrapper

def cached_cast(mod, fn, cast_fn, handle,
                try_caching=False, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(cast_fn, fn, verbose)
    wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
    utils.set_func_save(handle, mod, fn, wrapper)

# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
# is on the new API and I am free to get rid of handle, I can clean this up.
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)

        types = utils.collect_fp_tensor_types(args, kwargs)

        if len(types) <= 1:
            return orig_fn(*args, **kwargs)
        elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
            new_args = utils.casted_args(cast_fn,
                                         args,
                                         kwargs)
            return orig_fn(*new_args, **kwargs)
        else:
            raise NotImplementedError('Do not know how to handle ' +
                                      'these types to promote: {}'
                                      .format(types))
    return wrapper

def promote(mod, fn, handle, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    wrapper = make_promote_wrapper(orig_fn, maybe_float)
    utils.set_func_save(handle, mod, fn, wrapper)

def sequence_promote(mod, fn, handle, verbose=False):
    orig_fn = utils.get_func(mod, fn)
    maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(seq, *args, **kwargs):
        if not _amp_state.handle.is_active():
            return orig_fn(seq, *args, **kwargs)

        types = set([utils.type_string(x) for x in seq])
        if len(types) <= 1:
            return orig_fn(seq, *args, **kwargs)
        elif types == set(['HalfTensor', 'FloatTensor']):
            cast_seq = utils.casted_args(maybe_float,
                                         seq, {})
            return orig_fn(cast_seq, *args, **kwargs)
        else:
            # TODO: other mixed-type cases aren't due to amp.
            # Just pass through?
            return orig_fn(seq, *args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)

def promote_match_arg0(mod, fn, handle, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if not _amp_state.handle.is_active():
            return orig_fn(arg0, *args, **kwargs)

        if utils.type_string(arg0) == 'HalfTensor':
            cast_fn = utils.maybe_half
        elif utils.type_string(arg0) == 'FloatTensor':
            cast_fn = utils.maybe_float
        else:
            return orig_fn(arg0, *args, **kwargs)
        cast_fn = utils.verbosify(cast_fn, fn, verbose)
        new_args = utils.casted_args(cast_fn, args, kwargs)
        return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)

def err_if_any_half(mod, fn, handle, custom_err_msg=None):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        types = utils.collect_fp_tensor_types(args, kwargs)
        if 'HalfTensor' in types:
            if custom_err_msg:
                raise NotImplementedError(custom_err_msg)
            else:
                raise NotImplementedError('Cannot call in-place function ' +
                                          '{} with fp16 arguments.'.format(fn))
        else:
            return orig_fn(*args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)

def err_if_arg0_half(mod, fn, handle, verbose=False):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(arg0, *args, **kwargs):
        assert compat.is_tensor_like(arg0)
        if utils.type_string(arg0) == 'HalfTensor':
            raise NotImplementedError('Cannot call in-place method ' +
                                      '{} on fp16 Tensors.'.format(fn))
        else:
            cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
            new_args = utils.casted_args(cast_fn, args, kwargs)
            return orig_fn(arg0, *new_args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)

# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
# - Will call into either CudnnRNN or AutogradRNN
# - Each of these are factory functions that return a per-iter
#   `forward` function
# - We interpose on the factory function to:
#   1) Interpose on the actual forward function and put in casts
#   2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, handle, verbose=False):
    orig_rnn = utils.get_func(backend, fn)
    @functools.wraps(orig_rnn)
    def rnn_wrapper(*args, **kwargs):
        flat_weight = kwargs.get('flat_weight')
        if flat_weight is not None:
            # We replace `flat_weight` with an uninitialized fp16
            # Tensor. The "actual" weight tensors (provided in `forward`),
            # will then be set up as ptrs into the buffer and have the
            # corresponding fp32 values copied in.
            # We need to call `copy` on the "actual" weights so that the
            # autograd graph correctly backprops from the wgrads computed
            # inside cuDNN (on fp16 weights) into the fp32 weights.
            assert utils.type_string(flat_weight) == 'FloatTensor'
            if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
                # Pre-0.4. A little slower, since it zeros out memory.
                flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
            else:
                flat_weight_fp16 = torch.empty_like(flat_weight,
                                                    dtype=torch.float16)
            kwargs['flat_weight'] = flat_weight_fp16
        else:
            flat_weight_fp16 = None

        forward = orig_rnn(*args, **kwargs)
        @functools.wraps(forward)
        def fwd_wrapper(*fargs, **fkwargs):
            assert len(fargs) == 3 or len(fargs) == 4
            inputs, weights, hiddens = fargs[:3]
            assert utils.is_fp_tensor(inputs)
            assert isinstance(weights, list)
            cast_fn = utils.verbosify(utils.maybe_half,
                                      fn,
                                      verbose)
            new_args = []

            # 0) Inputs
            new_args.append(cast_fn(inputs))

            # 1) Weights
            if flat_weight_fp16 is not None:
                fp16_weights = utils.synthesize_flattened_rnn_weights(
                    weights, flat_weight_fp16, fn, verbose)
            else:
                fp16_weights = [[cast_fn(w) for w in layer]
                                for layer in weights]
            new_args.append(fp16_weights)

            # 2) Hiddens: either a tuple (for LSTM) or a single tensor
            if isinstance(hiddens, tuple):
                new_args.append(tuple(cast_fn(x) for x in hiddens))
            elif utils.is_fp_tensor(hiddens):
                new_args.append(cast_fn(hiddens))
            else:
                # Hiddens can, in principle, be `None` -- pass through
                new_args.append(hiddens)

            # 3) Batch sizes (0.4 or later only)
            if len(fargs) == 4:
                new_args.append(fargs[3])

            return forward(*new_args, **fkwargs)
        return fwd_wrapper
    utils.set_func_save(handle, backend, fn, rnn_wrapper)

def new_rnn_cast(fn, handle, verbose=False):
    # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
    # For rnn backend calls that route through _rnn_impls, we must patch the ref
    # that _rnn_impls stashed. For rnn backend calls that directly invoke
    # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
    # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
    if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
        mod = torch.nn.modules.rnn._rnn_impls
    else:
        mod = torch.nn.modules.rnn._VF
        assert isinstance(mod, rnn_compat.VariableFunctionsShim)
        fn = fn.lower()
    orig_fn = utils.get_func(mod, fn)
    cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        # Exact call signature from modules/rnn.py
        assert len(args) == 9
        assert len(kwargs) == 0

        if not _amp_state.handle.is_active():
            return orig_fn(*args, **kwargs)

        if isinstance(args[6], bool):
            params_idx = 2  # Not PackedSequence case
        else:
            params_idx = 3  # PackedSequence case

        new_args = []
        for i, arg in enumerate(args):
            if i == params_idx:
                num_params = sum([x.numel() for x in arg])
                fp16_weight_buf = args[0].new_empty((num_params,),
                                                    dtype=torch.half)
                casted_weights = utils.new_synthesize_flattened_rnn_weights(
                    arg, fp16_weight_buf, fn, verbose)
                new_args.append(casted_weights)
            elif utils.is_fp_tensor(arg):
                new_args.append(cast_fn(arg))
            else:
                new_args.append(arg)

        return orig_fn(*new_args)
    utils.set_func_save(handle, mod, fn, wrapper)

def disable_casts(mod, fn, handle):
    if not utils.has_func(mod, fn):
        return

    orig_fn = utils.get_func(mod, fn)
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        with handle._disable_casts():
            return orig_fn(*args, **kwargs)
    utils.set_func_save(handle, mod, fn, wrapper)
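Every wrapper in this file follows the same interposition pattern: fetch the original function from a module (or dict), wrap it with `functools.wraps`, then write the wrapper back in its place. A torch-free sketch of that pattern is below; the `demo` module and `add` function are made-up names for illustration.

```python
# Sketch of the get_func / wrapper / set_func interposition pattern used above.
# The module and function names here are made up for illustration.
import functools
import types

demo = types.ModuleType('demo')
demo.add = lambda a, b: a + b

def patch_with_cast(mod, fn_name, cast_fn):
    orig_fn = getattr(mod, fn_name)            # analogous to utils.get_func
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        new_args = [cast_fn(a) for a in args]  # analogous to utils.casted_args
        return orig_fn(*new_args, **kwargs)
    setattr(mod, fn_name, wrapper)             # analogous to utils.set_func

patch_with_cast(demo, 'add', float)
print(demo.add(1, 2))   # 3.0 -- the arguments were cast before the original ran
```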
apex/apex/fp16_utils/README.md
ADDED
@@ -0,0 +1,16 @@
fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing Pytorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user. To use `FP16_Optimizer`, only two lines of the user's training script need to change.

#### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)

#### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple)

#### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

#### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model)


fp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses.

#### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management)

The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling. These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically.
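As a concrete illustration of the "two lines change" claim above, a minimal training-loop sketch might look as follows. This assumes apex is installed, a CUDA device is available, and `data_loader` is a user-defined iterable of `(inputs, targets)` batches; the model and loss are placeholders.

```python
# Sketch of the two-line change described above (assumes apex is installed,
# CUDA is available, and data_loader is a user-defined iterable of batches).
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(512, 10).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)   # changed line 1

for inputs, targets in data_loader:
    optimizer.zero_grad()
    outputs = model(inputs.cuda().half())
    # Compute the criterion in fp32 for overflow safety, as the docstrings advise.
    loss = torch.nn.functional.cross_entropy(outputs.float(), targets.cuda())
    optimizer.backward(loss)                                      # changed line 2
    optimizer.step()
```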
apex/apex/fp16_utils/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .fp16util import (
    BN_convert_float,
    network_to_half,
    prep_param_lists,
    model_grads_to_master_grads,
    master_params_to_model_params,
    tofp16,
    to_python_float,
    clip_grad_norm,
    convert_module,
    convert_network,
    FP16Model,
)

from .fp16_optimizer import FP16_Optimizer
from .loss_scaler import LossScaler, DynamicLossScaler
apex/apex/fp16_utils/fp16_optimizer.py
ADDED
@@ -0,0 +1,643 @@
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from ..amp._amp_state import _amp_state, maybe_print
from ..amp.scaler import LossScaler
from ..multi_tensor_apply import multi_tensor_applier
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm

# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
    and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
    For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
    and changing the call to ``backward``.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                   # optional arg to control dynamic loss scaling behavior
                                   # dynamic_loss_args={'scale_window' : 500})
                                   # Usually, dynamic_loss_args is not necessary.

    Args:
        init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
        dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`LossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`LossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`LossScaler`'s defaults will be used.
        verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
    It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
    named to replace ``init_optimizer``, for two reasons:
    First, it means that references to the same name
    later in the file will not have to change.
    Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
    modify ``init_optimizer``. If you do choose a unique name for the new
    :class:`FP16_Optimizer` instance, you should only work with this new instance,
    because the preexisting optimizer might no longer behave as expected.

    ``init_optimizer`` may be any Pytorch optimizer.
    It may contain a mixture of fp16 and fp32 parameters organized into any number of
    ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
    ingest these ``param_groups`` and remember them.

    Calls to ::

        loss.backward()

    must be replaced with ::

        optimizer.backward(loss)

    because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
    loss scaling and copies to master gradients.

    .. note::
        Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
        are downscaled before being applied. This means that adjusting the loss scale, or using
        dynamic loss scaling, should not require retuning the learning rate or any other
        hyperparameters.


    **Advanced options**

    **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
    See docstring for :attr:`step`.

    **Gradient clipping**: Use :attr:`clip_master_grads`.

    **Multiple losses**: If your model accumulates gradients from multiple losses,
    this can be made more efficient by supplying ``update_master_grads=False``
    to :attr:`backward`. See docstring for :attr:`backward`.

    **Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::

        print(optimizer.loss_scale)
        optimizer.loss_scale = new_loss_scale

    For static loss scaling, manually adjusting the loss scale over time is a reasonable
    thing to do. During later epochs, gradients may become smaller, and a
    higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
    scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
    the loss scale is not recommended.

    **Multi-GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
    Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
    should still work as intended.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.verbose = verbose

        self.optimizer = init_optimizer
        # init_state_dict sets up an alternative way to cast per-param state tensors.
        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
        # init_state_dict = init_optimizer.state_dict()

        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                                         .format(param.size()))
                        fp16_params_this_group.append(param)
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        param_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                        # Reset existing state dict key to the new master param.
                        # We still need to recast per-param state tensors, if any, to FP32.
                        if param in self.optimizer.state:
                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                                         .format(param.size()))
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param
                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        self.all_fp16_params = []
        for group in self.fp16_groups:
            self.all_fp16_params += group

        self.all_fp32_from_fp16_params = []
        for group in self.fp32_from_fp16_groups:
            self.all_fp32_from_fp16_params += group

        self.all_fp32_from_fp32_params = []
        for group in self.fp32_from_fp32_groups:
            self.all_fp32_from_fp32_params += group

        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # alternative way to cast per-param state tensors:
        # self.optimizer.load_state_dict(init_state_dict)

        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            if dynamic_loss_args is not None:
                self.loss_scaler = LossScaler("dynamic", **dynamic_loss_args)
            else:
                self.loss_scaler = LossScaler("dynamic")
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(static_loss_scale)

        self.overflow = False
        self.first_closure_call_this_step = True

        self.clip_grad_norm = clip_grad_norm

        # TODO: Centralize exposure and import error checking for the C backend.
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_scale = amp_C.multi_tensor_scale
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact
    # of having to support FP16_Optimizer separately, for the time being.
    def maybe_print(self, msg):
        if self.verbose:
            print(msg)

    def __getstate__(self):
        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")

    def __setstate__(self, state):
        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")

    def zero_grad(self, set_grads_to_None=False):
        """
        Zero fp32 and fp16 parameter grads.
        """
        # In principle, only the .grad attributes of the model params need to be zeroed,
        # because gradients are copied into the FP32 master params. However, we zero
        # all gradients owned by the optimizer, just to be safe:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

        # Zero fp16 gradients owned by the model:
        for fp16_group in self.fp16_groups:
            for param in fp16_group:
                if set_grads_to_None:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
                        param.grad.zero_()

    # Should not be used anymore.
    # def _check_overflow(self):
    #     params = []
    #     for group in self.fp16_groups:
    #         for param in group:
    #             params.append(param)
    #     for group in self.fp32_from_fp32_groups:
    #         for param in group:
    #             params.append(param)
    #     self.overflow = self.loss_scaler.has_overflow(params)

    # def _update_scale(self, has_overflow=False):
    #     self.loss_scaler.update_scale(has_overflow)

    def _master_params_to_model_params(self):
        if multi_tensor_applier.available:
            if len(self.all_fp16_params) > 0:
                multi_tensor_applier(
                    self.multi_tensor_scale,
                    self._dummy_overflow_buf,
                    [self.all_fp32_from_fp16_params, self.all_fp16_params],
                    1.0)
        else:
            for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
                master_params_to_model_params(fp16_group, fp32_from_fp16_group)

    # To consider: Integrate distributed with this wrapper by registering a hook on each variable
    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
    # def _model_grads_to_master_grads(self):
    #     for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
    #         model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

    # def _downscale_master(self):
    #     if self.loss_scale != 1.0:
    #         for group in self.optimizer.param_groups:
    #             for param in group['params']:
    #                 if param.grad is not None:
    #                     param.grad.data.mul_(1./self.loss_scale)

    def clip_master_grads(self, max_norm, norm_type=2):
        """
        Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp32 gradients (viewed as a single vector).

        .. warning::
            Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
        """
        if not self.overflow:
            fp32_params = []
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    fp32_params.append(param)
            return self.clip_grad_norm(fp32_params, max_norm, norm_type)
        else:
            return -1

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']
        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def step(self, closure=None): # could add clip option.
        """
        If no closure is supplied, :attr:`step` should be called after
        ``fp16_optimizer_obj.backward(loss)``.
        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
        another forward pass using their model.

        If a closure is supplied, :attr:`step` may be called without a prior call to
        :attr:`backward(loss)`.
        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
        However, the user should take care that any ``loss.backward()`` call within the closure
        has been replaced by ``fp16_optimizer_obj.backward(loss)``.

        Args:
           closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.

        Example with closure::

            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
            # existing pytorch optimizer.
            for input, target in dataset:
                def closure():
                    optimizer.zero_grad()
                    output = model(input)
                    loss = loss_fn(output, target)
                    # loss.backward() becomes:
                    optimizer.backward(loss)
                    return loss
                optimizer.step(closure)

        .. warning::
            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.

        .. _`ordinary Pytorch optimizer use`:
            http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """

        scale = self.loss_scaler.loss_scale()
        # To consider: Should this be in step(), or update_master_grads? It works either way,
        # but I should make it consistent with the Amp control flow, which updates the scale
        # during backward context manager exit.
        # self._update_scale(self.overflow)

        if self.overflow:
            # Using _amp_state.maybe_print instead of self.print here is intentional.
            maybe_print("Gradient overflow. Skipping step, reducing " +
                        "loss scale to {}".format(self.loss_scaler.loss_scale()))
            return

        if closure is not None:
            retval = self._step_with_closure(closure)
        else:
            # torch.cuda.nvtx.range_push("pytorch optimizer step")
            retval = self.optimizer.step()
            # torch.cuda.nvtx.range_pop()

        self._master_params_to_model_params()

        return retval

    def _step_with_closure(self, closure):
        def wrapped_closure():
            # helpful for debugging
            # print("Calling wrapped_closure, first_closure_call_this_step = {}"
            #       .format(self.first_closure_call_this_step))
            if self.first_closure_call_this_step:
                # We expect that the fp16 params are initially fresh on entering self.step(),
                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
                # is called within self.optimizer.step().
                self.first_closure_call_this_step = False
            else:
                # If self.optimizer.step() internally calls wrapped_closure more than once,
                # it may update the fp32 params after each call. However, self.optimizer
                # doesn't know about the fp16 params at all. If the fp32 params get updated,
                # we can't rely on self.optimizer to refresh the fp16 params. We need
                # to handle that manually:
                self._master_params_to_model_params()
            # Our API expects the user to give us ownership of the backward() call by
            # replacing all calls to loss.backward() with optimizer.backward(loss).
            # This requirement holds whether or not the call to backward() is made within a closure.
            # If the user is properly calling optimizer.backward(loss) within "closure,"
            # calling closure() here will give the fp32 master params fresh gradients
            # for the optimizer to play with, so all wrapped_closure needs to do is call
            # closure() and return the loss.
            temp_loss = closure()
            while(self.overflow):
                scale = self.loss_scaler.loss_scale()
                # self._update_scale(self.overflow) # now done at the end of backward
                print("OVERFLOW within closure! Skipping step, reducing loss scale to {}".format(
                      self.loss_scaler.loss_scale()))
                temp_loss = closure()
            return temp_loss

        retval = self.optimizer.step(wrapped_closure)

        self.first_closure_call_this_step = True

        return retval

    def backward(self, loss, update_master_grads=True, retain_graph=False):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float() (see first Note below)
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
        5. Finally, master grads are divided by loss_scale.

        In this way, after :attr:`backward`, the master params have fresh gradients,
        and :attr:`step` may be called.

        .. note::
            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
            This provides some additional safety against overflow if the user has supplied an
            fp16 loss value.
            However, for maximum overflow safety, the user should
            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
            :attr:`backward`.

        .. warning::
            The gradients found in a model's leaves after the call to
            :attr:`backward` should not be regarded as valid in general,
            because it's possible
            they have been scaled (and in the case of dynamic loss scaling,
            the scale factor may change over time).
            If the user wants to inspect gradients after a call to :attr:`backward`,
            only the master gradients should be regarded as valid. These can be retrieved via
            :attr:`inspect_master_grad_data()`.

        Args:
            loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
            update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
            retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).

        Example::

            # Ordinary operation:
            optimizer.backward(loss)

            # Naive operation with multiple losses (technically valid, but less efficient):
            # fp32 grads will be correct after the second call, but
            # the first call incurs an unnecessary fp16->fp32 grad copy.
            optimizer.backward(loss1)
            optimizer.backward(loss2)

            # More efficient way to handle multiple losses:
            # The fp16->fp32 grad copy is delayed until fp16 grads from all
            # losses have been accumulated.
            optimizer.backward(loss1, update_master_grads=False)
            optimizer.backward(loss2, update_master_grads=False)
            optimizer.update_master_grads()
        """
        # To consider: try multiple backward passes using retain_grad=True to find
        # a loss scale that works. After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
        # discarding the iteration, but probably wouldn't improve overall efficiency.
        scaled_loss = loss.float()*self.loss_scaler.loss_scale()
        scaled_loss.backward(retain_graph=retain_graph)
        if update_master_grads:
            self.update_master_grads()

    def update_master_grads(self):
        # torch.cuda.nvtx.range_push("update_master_grads")
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
        updated by the optimizer. :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        # if self.dynamic_loss_scale:
        #     self._check_overflow()
        #     if self.overflow: return
        # self._model_grads_to_master_grads()
        # self._downscale_master()
        # Use the one-shot multi-tensor apply kernel
        self.loss_scaler.clear_overflow_state()
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
            # I'm ONLY writing this as an incremental way to make some tests pass until
            # I can refactor the tests as well.
            # FP16_Optimizer should not be used by anyone.
            model_grads = []
            master_grads = []
            for model_param, master_param in zip(self.all_fp16_params,
                                                 self.all_fp32_from_fp16_params):
                if model_param.grad is not None:
                    model_grads.append(model_param.grad)
                    if master_param.grad is None:
                        master_param.grad = torch.empty_like(master_param)
                    master_grads.append(master_param.grad)
            self.loss_scaler.unscale(
                model_grads,
                master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp16_params])
        if len(self.all_fp32_from_fp32_params) > 0:
            model_grads = []
            master_grads = []
            for model_param, master_param in zip(self.all_fp32_from_fp32_params,
                                                 self.all_fp32_from_fp32_params):
                if model_param.grad is not None:
                    model_grads.append(model_param.grad)
                    master_grads.append(master_param.grad)
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            self.loss_scaler.unscale(
                model_grads,
                master_grads,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
        # quit()
        self.overflow = self.loss_scaler.update_scale()
        # torch.cuda.nvtx.range_pop()


    def inspect_master_grad_data(self):
        """
        When running with :class:`FP16_Optimizer`,
        ``.grad`` attributes of a model's fp16 leaves should not be
        regarded as truthful, because they might be scaled.
        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
        the fp32 master params' ``.grad``
        attributes will contain valid gradients properly divided by the loss scale. However,
        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
        nonintuitive. :attr:`inspect_master_grad_data`
        allows those gradients to be viewed with shapes corresponding to their associated model leaves.

        Returns:
            List of lists (one list for each parameter group). The list for each parameter group
            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
        """
        if self.overflow:
            print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
                  "Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
            return None
        else:
            # The optimizer owns only references to master params.
            master_grads_data = []
            for param_group in self.optimizer.param_groups:
                master_grads_this_group = []
                for param in param_group['params']:
                    if param.grad is not None:
                        master_grads_this_group.append(param.grad.data)
                    else:
                        master_grads_this_group.append(None)
                master_grads_data.append(master_grads_this_group)
            return master_grads_data


    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale()

    def _set_loss_scale(self, value):
        self.loss_scaler._loss_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
apex/apex/fp16_utils/fp16util.py
ADDED
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class tofp16(nn.Module):
    """
    Utility module that implements::

        def forward(self, input):
            return input.half()
    """

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        return input.half()


def BN_convert_float(module):
    """
    Utility function for network_to_half().

    Retained for legacy purposes.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
        module.float()
    for child in module.children():
        BN_convert_float(child)
    return module


def network_to_half(network):
    """
    Convert model to half precision in a batchnorm-safe way.

    Retained for legacy purposes. It is recommended to use FP16Model.
    """
    return nn.Sequential(tofp16(), BN_convert_float(network.half()))


def convert_module(module, dtype):
    """
    Converts a module's immediate parameters and buffers to dtype.
    """
    for param in module.parameters(recurse=False):
        if param is not None:
            if param.data.dtype.is_floating_point:
                param.data = param.data.to(dtype=dtype)
            if param._grad is not None and param._grad.data.dtype.is_floating_point:
                param._grad.data = param._grad.data.to(dtype=dtype)

    for buf in module.buffers(recurse=False):
        if buf is not None and buf.data.dtype.is_floating_point:
            buf.data = buf.data.to(dtype=dtype)


def convert_network(network, dtype):
    """
    Converts a network's parameters and buffers to dtype.
    """
    for module in network.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
            continue
        convert_module(module, dtype)
        if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):
            module.flatten_parameters()
    return network


class FP16Model(nn.Module):
    """
    Convert model to half precision in a batchnorm-safe way.
    """

    def __init__(self, network):
        super(FP16Model, self).__init__()
        self.network = convert_network(network, dtype=torch.half)

    def forward(self, *inputs):
        inputs = tuple(t.half() for t in inputs)
        return self.network(*inputs)


def backwards_debug_hook(grad):
    raise RuntimeError("master_params received a gradient in the backward pass!")

def prep_param_lists(model, flat_master=False):
    """
    Creates a list of FP32 master parameters for a given model, as in
    `Training Neural Networks with Mixed Precision: Real Examples`_.

    Args:
        model (torch.nn.Module): Existing Pytorch model
        flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
    Returns:
        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element.

    Example::

        model_params, master_params = prep_param_lists(model)

    .. warning::
        Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.

    .. _`Training Neural Networks with Mixed Precision: Real Examples`:
        http://on-demand.gputechconf.com/gtc/2018/video/S81012/
    """
    model_params = [param for param in model.parameters() if param.requires_grad]

    if flat_master:
        # Give the user some more useful error messages
        try:
            # flatten_dense_tensors returns a contiguous flat array.
            # http://pytorch.org/docs/master/_modules/torch/_utils.html
            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
        except:
            print("Error in prep_param_lists: model may contain a mixture of parameters "
                  "of different types. Use flat_master=False, or use FP16_Optimizer.")
            raise
        master_params = torch.nn.Parameter(master_params)
        master_params.requires_grad = True
        # master_params.register_hook(backwards_debug_hook)
        if master_params.grad is None:
            master_params.grad = master_params.new(*master_params.size())
        return model_params, [master_params]
    else:
        master_params = [param.clone().float().detach() for param in model_params]
        for param in master_params:
            param.requires_grad = True
        return model_params, master_params


def model_grads_to_master_grads(model_params, master_params, flat_master=False):
    """
    Copy model gradients to master gradients.

    Args:
        model_params: List of model parameters created by :func:`prep_param_lists`.
        master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
    """
    if flat_master:
        # The flattening may incur one more deep copy than is necessary.
        master_params[0].grad.data.copy_(
            _flatten_dense_tensors([p.grad.data for p in model_params]))
    else:
        for model, master in zip(model_params, master_params):
            if model.grad is not None:
                if master.grad is None:
                    master.grad = Variable(master.data.new(*master.data.size()))
                master.grad.data.copy_(model.grad.data)
            else:
                master.grad = None


def master_params_to_model_params(model_params, master_params, flat_master=False):
    """
    Copy master parameters to model parameters.

    Args:
        model_params: List of model parameters created by :func:`prep_param_lists`.
        master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
    """
    if flat_master:
        for model, master in zip(model_params,
                                 _unflatten_dense_tensors(master_params[0].data, model_params)):
            model.data.copy_(master)
    else:
        for model, master in zip(model_params, master_params):
            model.data.copy_(master.data)

# Backward compatibility fixes

def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
    clip_grad_norm = torch.nn.utils.clip_grad_norm
else:
    clip_grad_norm = torch.nn.utils.clip_grad_norm_
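
Taken together, the helpers above describe the manual master-weights workflow: keep fp16 model parameters, keep fp32 master copies, copy grads over, unscale, step, and copy the updated weights back. The following is only a rough usage sketch under assumed conditions (a CUDA device, this `apex.fp16_utils` package importable, a toy model and a static loss scale chosen for illustration), not part of the uploaded files:

```
import torch
from apex.fp16_utils import (prep_param_lists, model_grads_to_master_grads,
                             master_params_to_model_params, network_to_half)

# Toy fp16 model; any batchnorm-free module works the same way.
model = network_to_half(torch.nn.Linear(128, 10)).cuda()
model_params, master_params = prep_param_lists(model)
optimizer = torch.optim.SGD(master_params, lr=1e-3)
loss_scale = 128.0  # assumed static loss scale

for _ in range(10):
    x = torch.randn(32, 128, device="cuda")
    target = torch.randn(32, 10, device="cuda")
    loss = torch.nn.functional.mse_loss(model(x).float(), target)

    model.zero_grad()
    (loss * loss_scale).backward()                   # scaled backward into fp16 .grad
    model_grads_to_master_grads(model_params, master_params)
    for p in master_params:                          # unscale in fp32
        if p.grad is not None:
            p.grad.data.mul_(1.0 / loss_scale)
    optimizer.step()                                 # fp32 update on master weights
    master_params_to_model_params(model_params, master_params)  # copy back to fp16
```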
apex/apex/fp16_utils/loss_scaler.py
ADDED
@@ -0,0 +1,186 @@
import torch

# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

class LossScaler:
    """
    Class that manages a static loss scale. This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    Args:
        scale (float, optional, default=1.0): The loss scale.
    """

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        return False

    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss, retain_graph=False):
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)

class DynamicLossScaler:
    """
    Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
    :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using
    the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

    Loss scaling is designed to combat the problem of underflowing gradients encountered at long
    times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
    occurred.
    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients detected,
    :class:`DynamicLossScaler` increases the loss scale once more.
    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
    always using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
        scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
        scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True

        return False

    # `x` is a torch.Tensor
    def _has_inf_or_nan(x):
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if inst is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False

    # `overflow` is boolean indicating whether the gradient overflowed
    def update_scale(self, overflow):
        if overflow:
            # self.cur_scale /= self.scale_factor
            self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss, retain_graph=False):
        scaled_loss = loss*self.loss_scale
        scaled_loss.backward(retain_graph=retain_graph)

##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)

    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
        print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow
        has_overflow = DynamicLossScaler.has_overflow(parameters)

        # If no overflow, unscale grad and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- ie, skip iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)

"""
apex/apex/multi_tensor_apply/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .multi_tensor_apply import MultiTensorApply

multi_tensor_applier = MultiTensorApply(2048*32)
apex/apex/multi_tensor_apply/multi_tensor_apply.py
ADDED
@@ -0,0 +1,30 @@
import torch

class MultiTensorApply(object):
    available = False
    warned = False

    def __init__(self, chunk_size):
        try:
            import amp_C
            MultiTensorApply.available = True
            self.chunk_size = chunk_size
        except ImportError as err:
            MultiTensorApply.available = False
            MultiTensorApply.import_err = err

    def check_avail(self):
        if MultiTensorApply.available == False:
            raise RuntimeError(
                "Attempted to call MultiTensorApply method, but MultiTensorApply "
                "is not available, possibly because Apex was installed without "
                "--cpp_ext --cuda_ext. Original import error message:",
                MultiTensorApply.import_err)

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        self.check_avail()

        return op(self.chunk_size,
                  noop_flag_buffer,
                  tensor_lists,
                  *args)
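
The `multi_tensor_applier` singleton simply forwards a chunk size, a no-op flag buffer, and lists of tensors to a fused CUDA op from the `amp_C` extension. As a rough illustration only (it assumes Apex was built with `--cpp_ext --cuda_ext` so that `amp_C` and its `multi_tensor_scale` op are importable, and uses toy tensors), a fused unscale/copy might look like:

```
import torch
import amp_C  # CUDA extension; only present when Apex is built with --cpp_ext --cuda_ext
from apex.multi_tensor_apply import multi_tensor_applier

# fp16 "model" gradients and fp32 "master" gradients (toy tensors for illustration)
model_grads = [torch.randn(1024, device="cuda", dtype=torch.float16) for _ in range(4)]
master_grads = [torch.empty(1024, device="cuda", dtype=torch.float32) for _ in range(4)]

# Becomes nonzero if an inf/NaN is encountered in any chunk.
overflow_buf = torch.zeros(1, device="cuda", dtype=torch.int)

loss_scale = 128.0
# Copies model_grads into master_grads while multiplying by 1/loss_scale, in fused chunks.
multi_tensor_applier(amp_C.multi_tensor_scale,
                     overflow_buf,
                     [model_grads, master_grads],
                     1.0 / loss_scale)
```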
apex/apex/normalization/__init__.py
ADDED
@@ -0,0 +1 @@
from .fused_layer_norm import FusedLayerNorm
apex/apex/normalization/fused_layer_norm.py
ADDED
@@ -0,0 +1,160 @@
import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F
import importlib

class FusedLayerNormAffineFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input, weight, bias):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, self.normalized_shape, weight_, bias_, self.eps)
        self.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, weight_, bias_, mean, invvar = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            weight_, bias_, self.eps)
        return grad_input, grad_weight, grad_bias

class FusedLayerNormFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input):
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, self.normalized_shape, self.eps)
        self.save_for_backward(input_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, mean, invvar = self.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            self.eps)
        return grad_input

def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
    return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)

def fused_layer_norm(input, normalized_shape, eps=1e-6):
    return FusedLayerNormFunction(normalized_shape,eps)(input)

class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    Currently only runs on cuda() tensors.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated separately over the last
    certain number dimensions which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which applies
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
                    \times \ldots \times \text{normalized}\_\text{shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = apex.normalization.FusedLayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = apex.normalization.FusedLayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(FusedLayerNorm, self).__init__()

        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        if not input.is_cuda:
            return F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
                input, self.weight, self.bias)
        else:
            return FusedLayerNormFunction(self.normalized_shape,self.eps)(
                input)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
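
Since the module above falls back to `torch.nn.functional.layer_norm` on CPU tensors and dispatches to the `fused_layer_norm_cuda` extension on GPU, a natural sanity check is a parity comparison against PyTorch's own `nn.LayerNorm`. A small hedged sketch (it assumes a CUDA device and a PyTorch/Apex combination for which this extension builds and the legacy-style autograd Functions still run):

```
import torch
from apex.normalization import FusedLayerNorm

torch.manual_seed(0)
x = torch.randn(20, 5, 10, 10, device="cuda")

fused = FusedLayerNorm([10, 10]).cuda()
reference = torch.nn.LayerNorm([10, 10]).cuda()

# Copy the affine parameters so both modules compute the same transform.
with torch.no_grad():
    reference.weight.copy_(fused.weight)
    reference.bias.copy_(fused.bias)

out_fused = fused(x)
out_ref = reference(x)
print("max abs difference:", (out_fused - out_ref).abs().max().item())
```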
apex/apex/optimizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .fused_adam import FusedAdam
from .fp16_optimizer import FP16_Optimizer
apex/apex/optimizers/fp16_optimizer.py
ADDED
@@ -0,0 +1,274 @@
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer.
    Designed only to wrap apex.optimizers.FusedAdam.
    Refer to the apex.fp16_utils documents for more information.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = apex.optimizers.FusedAdam(model.parameters())
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        # optional arg to control dynamic loss scaling behavior
        # dynamic_loss_args={'scale_window' : 500})
        # Usually, dynamic_loss_args is not necessary.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True):

        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain same user API from apex.fp16_utils
        # 2. keep common stuff here in case we need to add new fused optimizer later

        # differences from apex.fp16_utils:
        # - assume all model params in fp16
        # - assume all params requires grad
        # - flat by groups, not keeping state. TODO: remove state explicitly?
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to list before modify
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weight, flattened
            self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weight
            self.fp32_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            if dynamic_loss_args is not None:
                raise SystemError("Do not support dynamic loss scale args for now.")
            self.dynamic_loss_scale = True
            self.cur_scale = 2**16
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2
            self.scale_window = 1000
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
        self.verbose = verbose

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
        """
        Compute fp16 grad norm for later clipping (fused with update).
        Internally accumulated in fp32.
        Also fused in NaN check. Possibly other reduction needed for grad.

        Args:
            fp16_grads_flat (tensor): fp16 grad flattened
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp16 gradients (viewed as a single vector).
            Returns -1 if the most recently computed fp16 gradients overflowed
        """
        # TODO: Not most efficient with copy to cpu and sync
        # only support 2-norm now
        # for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
        try:
            norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
        except TypeError as err:
            norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            return -1
        else:
            return norm

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute norm for all group so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
            if norm_groups[i] == -1:  # TODO: early break
                skip = True

        if skip:
            self._update_scale(skip)
            return

        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        self._update_scale(False)
        return

    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()

    def _update_scale(self, skip):
        if self.dynamic_loss_scale:
            if skip:
                if self.verbose:
                    print("\nGrad overflow on iteration", self.cur_iter)
                    print("Using dynamic loss scale of", self.cur_scale)
                self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
                self.last_overflow_iter = self.cur_iter
            else:
                if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                    self.cur_scale *= self.scale_factor
        else:
            if skip:
                print("\nGrad overflow on iteration", self.cur_iter)
                print("Using static loss scale of", self.cur_scale)
        self.cur_iter += 1
        return

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::
            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.
        Example::
            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)
apex/apex/optimizers/fused_adam.py
ADDED
@@ -0,0 +1,147 @@
import types
import torch
import importlib

class FusedAdam(torch.optim.Optimizer):

    """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!
        eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
            adds eps to the bias-corrected second moment estimate before
            evaluating square root instead of adding it to the square root of
            second moment estimate as in the original paper. (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params,
                 lr=1e-3, bias_correction = True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")

        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(FusedAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): A reduced precision copy
                of the updated weights written out in addition to the regular
                updated weights. Have to be of same type as gradients. (default: None)
            scale (float, optional): factor to divide gradient tensor values
                by before applying to weights. (default: 1)
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility
        # assuming a list/generator of parameter means single group
        elif isinstance(grads, types.GeneratorType):
            grads_group = [grads]
        elif type(grads[0])!=list:
            grads_group = [grads]
        else:
            grads_group = grads

        if output_params is None:
            output_params_group = [None]*len(self.param_groups)
        elif isinstance(output_params, types.GeneratorType):
            output_params_group = [output_params]
        elif type(output_params[0])!=list:
            output_params_group = [output_params]
        else:
            output_params_group = output_params

        if grad_norms is None:
            grad_norms = [None]*len(self.param_groups)

        for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
            if grads_this_group is None:
                grads_this_group = [None]*len(group['params'])
            if output_params_this_group is None:
                output_params_this_group = [None]*len(group['params'])

            # compute combined scale factor for this group
            combined_scale = scale
            if group['max_grad_norm'] > 0:
                # norm is in fact norm*scale
                clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
                if clip > 1:
                    combined_scale = clip * scale

            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                # note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                fused_adam_cuda.adam(p.data,
                                     out_p,
                                     exp_avg,
                                     exp_avg_sq,
                                     grad,
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     combined_scale,
                                     state['step'],
                                     self.eps_mode,
                                     bias_correction,
                                     group['weight_decay'])
        return loss
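
The docstrings of `FusedAdam` and the cut-down `FP16_Optimizer` above describe how they are meant to be paired: the wrapper scales the loss, flattens and norm-checks the fp16 gradients, and then hands them to the fused Adam kernel. A minimal hedged sketch of that pairing, with a toy fp16 model and illustrative sizes (it assumes an Apex build with `--cuda_ext --cpp_ext`):

```
import torch
from apex.optimizers import FusedAdam, FP16_Optimizer

D_in, D_out = 512, 128  # illustrative sizes
model = torch.nn.Linear(D_in, D_out).cuda().half()

optimizer = FusedAdam(model.parameters(), lr=1e-4)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for _ in range(100):
    x = torch.randn(64, D_in, device="cuda", dtype=torch.float16)
    target = torch.randn(64, D_out, device="cuda", dtype=torch.float16)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x).float(), target.float())
    optimizer.backward(loss)   # scales the loss, then calls backward()
    optimizer.step()           # unscale + fused Adam update (skipped on overflow)
```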
apex/apex/parallel/LARC.py
ADDED
@@ -0,0 +1,97 @@
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

class LARC(object):
    """
    :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,
    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
    local learning rate for each individual parameter. The algorithm is designed to improve
    convergence of large batch training.

    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.

    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.

    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARC(optim)
    ```

    It can even be used in conjunction with apex.fp16_utils.FP16_Optimizer.

    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARC(optim)
    optim = apex.fp16_utils.FP16_Optimizer(optim)
    ```

    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
        clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
    """

    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        # calculate adaptive lr + weight decay
                        adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)

                        # clip learning rate for LARC
                        if self.clip:
                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                            adaptive_lr = min(adaptive_lr/group['lr'], 1)

                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
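
Because `LARC` exposes the wrapped optimizer's `param_groups`, a standard learning-rate scheduler can keep driving the global learning rate while LARC rescales each parameter's gradient. A small hedged sketch of the scaling variant (`clip=False`); the model, sizes, and hyperparameters are placeholders, and the per-iteration `scheduler.step()` is only for illustration:

```
import torch
from apex.parallel.LARC import LARC

model = torch.nn.Linear(256, 10).cuda()
base_optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
optim = LARC(base_optim, trust_coefficient=0.02, clip=False)  # scaling mode: lr = local_lr * optimizer_lr
scheduler = torch.optim.lr_scheduler.StepLR(base_optim, step_size=30, gamma=0.1)

for _ in range(90):
    x = torch.randn(128, 256, device="cuda")
    target = torch.randint(0, 10, (128,), device="cuda")
    optim.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), target)
    loss.backward()
    optim.step()       # rescales each parameter's gradient, then runs the wrapped SGD step
    scheduler.step()
```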
apex/apex/parallel/README.md
ADDED
@@ -0,0 +1,66 @@
## Distributed Data Parallel

distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library.

`apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with
computation in the backward pass and bucketing smaller transfers to reduce the total number of
transfers required.

multiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs.

#### [API Documentation](https://nvidia.github.io/apex/parallel.html)

#### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)

#### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

#### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex)

### Synchronized Batch Normalization

`apex.parallel.SyncBatchNorm` has a similar API to `torch.nn.BatchNorm*N*d`.
It reduces stats on the first (channel) dimension of the Tensor and accepts
arbitrary spatial dimensions.

#### Installation

Apex provides two sync BN implementations:

1. The Python-only implementation, which is the default when installing with
`python setup.py install`.
It uses PyTorch primitive operations and the distributed communication package from
`torch.distributed`.

   - _The Python-only implementation requires the input tensor to be of the same data type as the
layer._

2. An implementation with custom kernels through a CUDA/C++ extension, with
improved performance. We are experimenting with Welford and Kahan summation for the reduction,
hoping to get better accuracy.
To use the kernel implementation, install Apex with the CUDA extension
enabled: `python setup.py install --cuda_ext`.

   - _The custom kernel implementation supports fp16 input with an fp32 layer, as cudnn does.
This is required to run the imagenet example in fp16._

   - _Currently the kernel implementation only supports GPU._

#### HowTo

1. Users can use `apex.parallel.SyncBatchNorm` by building their module with
the layer explicitly.

```
import torch
import apex
input_t = torch.randn(3, 5, 20).cuda()
sbn = apex.parallel.SyncBatchNorm(5).cuda()
output_t = sbn(input_t)
```

2. Users can also take a constructed `torch.nn.Module` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`.

```
# model is an instance of torch.nn.Module
import apex
sync_bn_model = apex.parallel.convert_syncbn_model(model)
```
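
The DDP section above describes the wrapper but defers the walkthroughs to external links, so here is a minimal hedged sketch of how the pieces are commonly combined: one process per GPU (launched, for example, with `torch.distributed.launch`, which supplies `--local_rank` and the NCCL environment variables). The model, sizes, and argument handling are placeholders.

```
import argparse
import torch
import apex
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
).cuda()

# Swap BatchNorm layers for SyncBatchNorm, then wrap for multi-process data parallelism.
model = apex.parallel.convert_syncbn_model(model)
model = DDP(model)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 3, 32, 32, device="cuda")
loss = model(x).mean()
optimizer.zero_grad()
loss.backward()   # gradients are allreduced across processes during backward
optimizer.step()
```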