initial commit from cc-ai/climateGAN
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +143 -0
- Contributing.md +56 -0
- LICENSE +674 -0
- README.md +174 -3
- USAGE.md +328 -0
- apply_events.py +642 -0
- climategan/__init__.py +9 -0
- climategan/blocks.py +398 -0
- climategan/bn_fusion.py +137 -0
- climategan/data.py +539 -0
- climategan/deeplab/__init__.py +101 -0
- climategan/deeplab/deeplab_v2.py +198 -0
- climategan/deeplab/deeplab_v3.py +271 -0
- climategan/deeplab/mobilenet_v3.py +324 -0
- climategan/deeplab/resnet101_v3.py +203 -0
- climategan/deeplab/resnetmulti_v2.py +136 -0
- climategan/depth.py +230 -0
- climategan/discriminator.py +361 -0
- climategan/eval_metrics.py +635 -0
- climategan/fid.py +561 -0
- climategan/fire.py +133 -0
- climategan/generator.py +411 -0
- climategan/logger.py +445 -0
- climategan/losses.py +620 -0
- climategan/masker.py +234 -0
- climategan/norms.py +186 -0
- climategan/optim.py +291 -0
- climategan/painter.py +171 -0
- climategan/strings.py +99 -0
- climategan/trainer.py +1939 -0
- climategan/transforms.py +626 -0
- climategan/tutils.py +721 -0
- climategan/utils.py +1063 -0
- eval_masker.py +796 -0
- figures/ablation_comparison.py +394 -0
- figures/bootstrap_ablation.py +562 -0
- figures/bootstrap_ablation_summary.py +361 -0
- figures/human_evaluation.py +208 -0
- figures/labels.py +200 -0
- figures/metrics.py +676 -0
- figures/metrics_onefig.py +772 -0
- requirements-3.8.2.txt +91 -0
- requirements-any.txt +20 -0
- sbatch.py +933 -0
- shared/experiment/showcase.yaml +71 -0
- shared/template/mila_victor.sh +24 -0
- shared/template/resume_mila_victor.sh +24 -0
- shared/trainer/config.yaml +16 -0
- shared/trainer/defaults.yaml +334 -0
.gitattributes
CHANGED
@@ -31,3 +31,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
31 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
31 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
34 |
+
images/flood.png filter=lfs diff=lfs merge=lfs -text
|
35 |
+
images/smog.png filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/wildfire.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
omnienv/
|
2 |
+
example_data/
|
3 |
+
.vscode/
|
4 |
+
.comet.config
|
5 |
+
.DS_Store
|
6 |
+
config/
|
7 |
+
tests/not_committed/
|
8 |
+
*.hydra
|
9 |
+
outputs/
|
10 |
+
eval_folder*
|
11 |
+
|
12 |
+
# Byte-compiled / optimized / DLL files
|
13 |
+
__pycache__/
|
14 |
+
*.py[cod]
|
15 |
+
*$py.class
|
16 |
+
|
17 |
+
# C extensions
|
18 |
+
*.so
|
19 |
+
|
20 |
+
# Distribution / packaging
|
21 |
+
.Python
|
22 |
+
build/
|
23 |
+
develop-eggs/
|
24 |
+
dist/
|
25 |
+
downloads/
|
26 |
+
eggs/
|
27 |
+
.eggs/
|
28 |
+
lib/
|
29 |
+
lib64/
|
30 |
+
parts/
|
31 |
+
sdist/
|
32 |
+
var/
|
33 |
+
wheels/
|
34 |
+
pip-wheel-metadata/
|
35 |
+
share/python-wheels/
|
36 |
+
*.egg-info/
|
37 |
+
.installed.cfg
|
38 |
+
*.egg
|
39 |
+
MANIFEST
|
40 |
+
|
41 |
+
# PyInstaller
|
42 |
+
# Usually these files are written by a python script from a template
|
43 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
44 |
+
*.manifest
|
45 |
+
*.spec
|
46 |
+
|
47 |
+
# Installer logs
|
48 |
+
pip-log.txt
|
49 |
+
pip-delete-this-directory.txt
|
50 |
+
|
51 |
+
# Unit test / coverage reports
|
52 |
+
htmlcov/
|
53 |
+
.tox/
|
54 |
+
.nox/
|
55 |
+
.coverage
|
56 |
+
.coverage.*
|
57 |
+
.cache
|
58 |
+
nosetests.xml
|
59 |
+
coverage.xml
|
60 |
+
*.cover
|
61 |
+
*.py,cover
|
62 |
+
.hypothesis/
|
63 |
+
.pytest_cache/
|
64 |
+
|
65 |
+
# Translations
|
66 |
+
*.mo
|
67 |
+
*.pot
|
68 |
+
|
69 |
+
# Django stuff:
|
70 |
+
*.log
|
71 |
+
local_settings.py
|
72 |
+
db.sqlite3
|
73 |
+
db.sqlite3-journal
|
74 |
+
|
75 |
+
# Flask stuff:
|
76 |
+
instance/
|
77 |
+
.webassets-cache
|
78 |
+
|
79 |
+
# Scrapy stuff:
|
80 |
+
.scrapy
|
81 |
+
|
82 |
+
# Sphinx documentation
|
83 |
+
docs/_build/
|
84 |
+
|
85 |
+
# PyBuilder
|
86 |
+
target/
|
87 |
+
|
88 |
+
# Jupyter Notebook
|
89 |
+
.ipynb_checkpoints
|
90 |
+
|
91 |
+
# IPython
|
92 |
+
profile_default/
|
93 |
+
ipython_config.py
|
94 |
+
|
95 |
+
# pyenv
|
96 |
+
.python-version
|
97 |
+
|
98 |
+
# pipenv
|
99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
102 |
+
# install all needed dependencies.
|
103 |
+
#Pipfile.lock
|
104 |
+
|
105 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
106 |
+
__pypackages__/
|
107 |
+
|
108 |
+
# Celery stuff
|
109 |
+
celerybeat-schedule
|
110 |
+
celerybeat.pid
|
111 |
+
|
112 |
+
# SageMath parsed files
|
113 |
+
*.sage.py
|
114 |
+
|
115 |
+
# Environments
|
116 |
+
.env
|
117 |
+
.venv
|
118 |
+
env/
|
119 |
+
venv/
|
120 |
+
ENV/
|
121 |
+
env.bak/
|
122 |
+
venv.bak/
|
123 |
+
|
124 |
+
# Spyder project settings
|
125 |
+
.spyderproject
|
126 |
+
.spyproject
|
127 |
+
|
128 |
+
# Rope project settings
|
129 |
+
.ropeproject
|
130 |
+
|
131 |
+
# mkdocs documentation
|
132 |
+
/site
|
133 |
+
|
134 |
+
# mypy
|
135 |
+
.mypy_cache/
|
136 |
+
.dmypy.json
|
137 |
+
dmypy.json
|
138 |
+
|
139 |
+
# Pyre type checker
|
140 |
+
.pyre/
|
141 |
+
|
142 |
+
# local visualize tool
|
143 |
+
visualizedEval.py
|
Contributing.md
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1. Understand the file structure:
|
2 |
+
1. architectures in `discriminator.py` `generator.py` `classifier.py`
|
3 |
+
2. data loading in `data.py`
|
4 |
+
3. data transformation `transforms.py`
|
5 |
+
4. optimizers in `optim.py`
|
6 |
+
5. utilities in `utils.py`
|
7 |
+
6. training procedure in `trainer.py`
|
8 |
+
2. Write **tests** in `tests/`
|
9 |
+
1. your file should match `test_*.py`
|
10 |
+
2. update existing tests when adding functionalities
|
11 |
+
3. run tests regularly to check you haven't broken anything `python tests/run.py`
|
12 |
+
3. Add **WIP** in your PR's title when not ready to merge
|
13 |
+
5. Open an Issue if something's odd, or to assign yourself a todo
|
14 |
+
6. **Format your code** with [black](https://github.com/psf/black)
|
15 |
+
7. Only update `trainer/defaults.yaml` with values that should be shared across runs and users
|
16 |
+
1. use `config/trainer/local_tests.yaml` or any other to setup your particular config overriding `trainer/defaults.yaml`
|
17 |
+
|
18 |
+
## Running tests
|
19 |
+
|
20 |
+
As per `7.` you should set your particular config in `config/local_tests.yaml`. Mine looks like:
|
21 |
+
|
22 |
+
```yaml
|
23 |
+
output_path: /Users/victor/Documents/ccai/github/climategan/example_data
|
24 |
+
# -------------------
|
25 |
+
# ----- Tasks -----
|
26 |
+
# -------------------
|
27 |
+
#tasks: [a, d, h, s, t, w]
|
28 |
+
tasks: [a, d, s, t] # for now no h or w
|
29 |
+
# ----------------
|
30 |
+
# ----- Data -----
|
31 |
+
# ----------------
|
32 |
+
data:
|
33 |
+
files: # if one is not none it will override the dirs location
|
34 |
+
base: /Users/victor/Documents/ccai/github/climategan/example_data
|
35 |
+
transforms:
|
36 |
+
- name: hflip
|
37 |
+
ignore: false
|
38 |
+
p: 0.5
|
39 |
+
- name: resize
|
40 |
+
ignore: false
|
41 |
+
new_size: 256
|
42 |
+
- name: crop
|
43 |
+
ignore: false
|
44 |
+
height: 64
|
45 |
+
width: 64
|
46 |
+
gen:
|
47 |
+
encoder:
|
48 |
+
n_res: 1
|
49 |
+
default:
|
50 |
+
n_res: 1
|
51 |
+
|
52 |
+
train:
|
53 |
+
log_level: 1
|
54 |
+
```
|
55 |
+
|
56 |
+
Setting `n_res` to 1 is important to run tests faster and with less memory
|
LICENSE
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 29 June 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works.
|
12 |
+
|
13 |
+
The licenses for most software and other practical works are designed
|
14 |
+
to take away your freedom to share and change the works. By contrast,
|
15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
16 |
+
share and change all versions of a program--to make sure it remains free
|
17 |
+
software for all its users. We, the Free Software Foundation, use the
|
18 |
+
GNU General Public License for most of our software; it applies also to
|
19 |
+
any other work released this way by its authors. You can apply it to
|
20 |
+
your programs, too.
|
21 |
+
|
22 |
+
When we speak of free software, we are referring to freedom, not
|
23 |
+
price. Our General Public Licenses are designed to make sure that you
|
24 |
+
have the freedom to distribute copies of free software (and charge for
|
25 |
+
them if you wish), that you receive source code or can get it if you
|
26 |
+
want it, that you can change the software or use pieces of it in new
|
27 |
+
free programs, and that you know you can do these things.
|
28 |
+
|
29 |
+
To protect your rights, we need to prevent others from denying you
|
30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
31 |
+
certain responsibilities if you distribute copies of the software, or if
|
32 |
+
you modify it: responsibilities to respect the freedom of others.
|
33 |
+
|
34 |
+
For example, if you distribute copies of such a program, whether
|
35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
36 |
+
freedoms that you received. You must make sure that they, too, receive
|
37 |
+
or can get the source code. And you must show them these terms so they
|
38 |
+
know their rights.
|
39 |
+
|
40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
43 |
+
|
44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
45 |
+
that there is no warranty for this free software. For both users' and
|
46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
47 |
+
changed, so that their problems will not be attributed erroneously to
|
48 |
+
authors of previous versions.
|
49 |
+
|
50 |
+
Some devices are designed to deny users access to install or run
|
51 |
+
modified versions of the software inside them, although the manufacturer
|
52 |
+
can do so. This is fundamentally incompatible with the aim of
|
53 |
+
protecting users' freedom to change the software. The systematic
|
54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
56 |
+
have designed this version of the GPL to prohibit the practice for those
|
57 |
+
products. If such problems arise substantially in other domains, we
|
58 |
+
stand ready to extend this provision to those domains in future versions
|
59 |
+
of the GPL, as needed to protect the freedom of users.
|
60 |
+
|
61 |
+
Finally, every program is threatened constantly by software patents.
|
62 |
+
States should not allow patents to restrict development and use of
|
63 |
+
software on general-purpose computers, but in those that do, we wish to
|
64 |
+
avoid the special danger that patents applied to a free program could
|
65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
66 |
+
patents cannot be used to render the program non-free.
|
67 |
+
|
68 |
+
The precise terms and conditions for copying, distribution and
|
69 |
+
modification follow.
|
70 |
+
|
71 |
+
TERMS AND CONDITIONS
|
72 |
+
|
73 |
+
0. Definitions.
|
74 |
+
|
75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
76 |
+
|
77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
78 |
+
works, such as semiconductor masks.
|
79 |
+
|
80 |
+
"The Program" refers to any copyrightable work licensed under this
|
81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
82 |
+
"recipients" may be individuals or organizations.
|
83 |
+
|
84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
85 |
+
in a fashion requiring copyright permission, other than the making of an
|
86 |
+
exact copy. The resulting work is called a "modified version" of the
|
87 |
+
earlier work or a work "based on" the earlier work.
|
88 |
+
|
89 |
+
A "covered work" means either the unmodified Program or a work based
|
90 |
+
on the Program.
|
91 |
+
|
92 |
+
To "propagate" a work means to do anything with it that, without
|
93 |
+
permission, would make you directly or secondarily liable for
|
94 |
+
infringement under applicable copyright law, except executing it on a
|
95 |
+
computer or modifying a private copy. Propagation includes copying,
|
96 |
+
distribution (with or without modification), making available to the
|
97 |
+
public, and in some countries other activities as well.
|
98 |
+
|
99 |
+
To "convey" a work means any kind of propagation that enables other
|
100 |
+
parties to make or receive copies. Mere interaction with a user through
|
101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
102 |
+
|
103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
104 |
+
to the extent that it includes a convenient and prominently visible
|
105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
106 |
+
tells the user that there is no warranty for the work (except to the
|
107 |
+
extent that warranties are provided), that licensees may convey the
|
108 |
+
work under this License, and how to view a copy of this License. If
|
109 |
+
the interface presents a list of user commands or options, such as a
|
110 |
+
menu, a prominent item in the list meets this criterion.
|
111 |
+
|
112 |
+
1. Source Code.
|
113 |
+
|
114 |
+
The "source code" for a work means the preferred form of the work
|
115 |
+
for making modifications to it. "Object code" means any non-source
|
116 |
+
form of a work.
|
117 |
+
|
118 |
+
A "Standard Interface" means an interface that either is an official
|
119 |
+
standard defined by a recognized standards body, or, in the case of
|
120 |
+
interfaces specified for a particular programming language, one that
|
121 |
+
is widely used among developers working in that language.
|
122 |
+
|
123 |
+
The "System Libraries" of an executable work include anything, other
|
124 |
+
than the work as a whole, that (a) is included in the normal form of
|
125 |
+
packaging a Major Component, but which is not part of that Major
|
126 |
+
Component, and (b) serves only to enable use of the work with that
|
127 |
+
Major Component, or to implement a Standard Interface for which an
|
128 |
+
implementation is available to the public in source code form. A
|
129 |
+
"Major Component", in this context, means a major essential component
|
130 |
+
(kernel, window system, and so on) of the specific operating system
|
131 |
+
(if any) on which the executable work runs, or a compiler used to
|
132 |
+
produce the work, or an object code interpreter used to run it.
|
133 |
+
|
134 |
+
The "Corresponding Source" for a work in object code form means all
|
135 |
+
the source code needed to generate, install, and (for an executable
|
136 |
+
work) run the object code and to modify the work, including scripts to
|
137 |
+
control those activities. However, it does not include the work's
|
138 |
+
System Libraries, or general-purpose tools or generally available free
|
139 |
+
programs which are used unmodified in performing those activities but
|
140 |
+
which are not part of the work. For example, Corresponding Source
|
141 |
+
includes interface definition files associated with source files for
|
142 |
+
the work, and the source code for shared libraries and dynamically
|
143 |
+
linked subprograms that the work is specifically designed to require,
|
144 |
+
such as by intimate data communication or control flow between those
|
145 |
+
subprograms and other parts of the work.
|
146 |
+
|
147 |
+
The Corresponding Source need not include anything that users
|
148 |
+
can regenerate automatically from other parts of the Corresponding
|
149 |
+
Source.
|
150 |
+
|
151 |
+
The Corresponding Source for a work in source code form is that
|
152 |
+
same work.
|
153 |
+
|
154 |
+
2. Basic Permissions.
|
155 |
+
|
156 |
+
All rights granted under this License are granted for the term of
|
157 |
+
copyright on the Program, and are irrevocable provided the stated
|
158 |
+
conditions are met. This License explicitly affirms your unlimited
|
159 |
+
permission to run the unmodified Program. The output from running a
|
160 |
+
covered work is covered by this License only if the output, given its
|
161 |
+
content, constitutes a covered work. This License acknowledges your
|
162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
163 |
+
|
164 |
+
You may make, run and propagate covered works that you do not
|
165 |
+
convey, without conditions so long as your license otherwise remains
|
166 |
+
in force. You may convey covered works to others for the sole purpose
|
167 |
+
of having them make modifications exclusively for you, or provide you
|
168 |
+
with facilities for running those works, provided that you comply with
|
169 |
+
the terms of this License in conveying all material for which you do
|
170 |
+
not control copyright. Those thus making or running the covered works
|
171 |
+
for you must do so exclusively on your behalf, under your direction
|
172 |
+
and control, on terms that prohibit them from making any copies of
|
173 |
+
your copyrighted material outside their relationship with you.
|
174 |
+
|
175 |
+
Conveying under any other circumstances is permitted solely under
|
176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
177 |
+
makes it unnecessary.
|
178 |
+
|
179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
180 |
+
|
181 |
+
No covered work shall be deemed part of an effective technological
|
182 |
+
measure under any applicable law fulfilling obligations under article
|
183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
184 |
+
similar laws prohibiting or restricting circumvention of such
|
185 |
+
measures.
|
186 |
+
|
187 |
+
When you convey a covered work, you waive any legal power to forbid
|
188 |
+
circumvention of technological measures to the extent such circumvention
|
189 |
+
is effected by exercising rights under this License with respect to
|
190 |
+
the covered work, and you disclaim any intention to limit operation or
|
191 |
+
modification of the work as a means of enforcing, against the work's
|
192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
193 |
+
technological measures.
|
194 |
+
|
195 |
+
4. Conveying Verbatim Copies.
|
196 |
+
|
197 |
+
You may convey verbatim copies of the Program's source code as you
|
198 |
+
receive it, in any medium, provided that you conspicuously and
|
199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
200 |
+
keep intact all notices stating that this License and any
|
201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
202 |
+
keep intact all notices of the absence of any warranty; and give all
|
203 |
+
recipients a copy of this License along with the Program.
|
204 |
+
|
205 |
+
You may charge any price or no price for each copy that you convey,
|
206 |
+
and you may offer support or warranty protection for a fee.
|
207 |
+
|
208 |
+
5. Conveying Modified Source Versions.
|
209 |
+
|
210 |
+
You may convey a work based on the Program, or the modifications to
|
211 |
+
produce it from the Program, in the form of source code under the
|
212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
213 |
+
|
214 |
+
a) The work must carry prominent notices stating that you modified
|
215 |
+
it, and giving a relevant date.
|
216 |
+
|
217 |
+
b) The work must carry prominent notices stating that it is
|
218 |
+
released under this License and any conditions added under section
|
219 |
+
7. This requirement modifies the requirement in section 4 to
|
220 |
+
"keep intact all notices".
|
221 |
+
|
222 |
+
c) You must license the entire work, as a whole, under this
|
223 |
+
License to anyone who comes into possession of a copy. This
|
224 |
+
License will therefore apply, along with any applicable section 7
|
225 |
+
additional terms, to the whole of the work, and all its parts,
|
226 |
+
regardless of how they are packaged. This License gives no
|
227 |
+
permission to license the work in any other way, but it does not
|
228 |
+
invalidate such permission if you have separately received it.
|
229 |
+
|
230 |
+
d) If the work has interactive user interfaces, each must display
|
231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
233 |
+
work need not make them do so.
|
234 |
+
|
235 |
+
A compilation of a covered work with other separate and independent
|
236 |
+
works, which are not by their nature extensions of the covered work,
|
237 |
+
and which are not combined with it such as to form a larger program,
|
238 |
+
in or on a volume of a storage or distribution medium, is called an
|
239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
240 |
+
used to limit the access or legal rights of the compilation's users
|
241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
242 |
+
in an aggregate does not cause this License to apply to the other
|
243 |
+
parts of the aggregate.
|
244 |
+
|
245 |
+
6. Conveying Non-Source Forms.
|
246 |
+
|
247 |
+
You may convey a covered work in object code form under the terms
|
248 |
+
of sections 4 and 5, provided that you also convey the
|
249 |
+
machine-readable Corresponding Source under the terms of this License,
|
250 |
+
in one of these ways:
|
251 |
+
|
252 |
+
a) Convey the object code in, or embodied in, a physical product
|
253 |
+
(including a physical distribution medium), accompanied by the
|
254 |
+
Corresponding Source fixed on a durable physical medium
|
255 |
+
customarily used for software interchange.
|
256 |
+
|
257 |
+
b) Convey the object code in, or embodied in, a physical product
|
258 |
+
(including a physical distribution medium), accompanied by a
|
259 |
+
written offer, valid for at least three years and valid for as
|
260 |
+
long as you offer spare parts or customer support for that product
|
261 |
+
model, to give anyone who possesses the object code either (1) a
|
262 |
+
copy of the Corresponding Source for all the software in the
|
263 |
+
product that is covered by this License, on a durable physical
|
264 |
+
medium customarily used for software interchange, for a price no
|
265 |
+
more than your reasonable cost of physically performing this
|
266 |
+
conveying of source, or (2) access to copy the
|
267 |
+
Corresponding Source from a network server at no charge.
|
268 |
+
|
269 |
+
c) Convey individual copies of the object code with a copy of the
|
270 |
+
written offer to provide the Corresponding Source. This
|
271 |
+
alternative is allowed only occasionally and noncommercially, and
|
272 |
+
only if you received the object code with such an offer, in accord
|
273 |
+
with subsection 6b.
|
274 |
+
|
275 |
+
d) Convey the object code by offering access from a designated
|
276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
277 |
+
Corresponding Source in the same way through the same place at no
|
278 |
+
further charge. You need not require recipients to copy the
|
279 |
+
Corresponding Source along with the object code. If the place to
|
280 |
+
copy the object code is a network server, the Corresponding Source
|
281 |
+
may be on a different server (operated by you or a third party)
|
282 |
+
that supports equivalent copying facilities, provided you maintain
|
283 |
+
clear directions next to the object code saying where to find the
|
284 |
+
Corresponding Source. Regardless of what server hosts the
|
285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
286 |
+
available for as long as needed to satisfy these requirements.
|
287 |
+
|
288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
289 |
+
you inform other peers where the object code and Corresponding
|
290 |
+
Source of the work are being offered to the general public at no
|
291 |
+
charge under subsection 6d.
|
292 |
+
|
293 |
+
A separable portion of the object code, whose source code is excluded
|
294 |
+
from the Corresponding Source as a System Library, need not be
|
295 |
+
included in conveying the object code work.
|
296 |
+
|
297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
298 |
+
tangible personal property which is normally used for personal, family,
|
299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
302 |
+
product received by a particular user, "normally used" refers to a
|
303 |
+
typical or common use of that class of product, regardless of the status
|
304 |
+
of the particular user or of the way in which the particular user
|
305 |
+
actually uses, or expects or is expected to use, the product. A product
|
306 |
+
is a consumer product regardless of whether the product has substantial
|
307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
308 |
+
the only significant mode of use of the product.
|
309 |
+
|
310 |
+
"Installation Information" for a User Product means any methods,
|
311 |
+
procedures, authorization keys, or other information required to install
|
312 |
+
and execute modified versions of a covered work in that User Product from
|
313 |
+
a modified version of its Corresponding Source. The information must
|
314 |
+
suffice to ensure that the continued functioning of the modified object
|
315 |
+
code is in no case prevented or interfered with solely because
|
316 |
+
modification has been made.
|
317 |
+
|
318 |
+
If you convey an object code work under this section in, or with, or
|
319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
320 |
+
part of a transaction in which the right of possession and use of the
|
321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
322 |
+
fixed term (regardless of how the transaction is characterized), the
|
323 |
+
Corresponding Source conveyed under this section must be accompanied
|
324 |
+
by the Installation Information. But this requirement does not apply
|
325 |
+
if neither you nor any third party retains the ability to install
|
326 |
+
modified object code on the User Product (for example, the work has
|
327 |
+
been installed in ROM).
|
328 |
+
|
329 |
+
The requirement to provide Installation Information does not include a
|
330 |
+
requirement to continue to provide support service, warranty, or updates
|
331 |
+
for a work that has been modified or installed by the recipient, or for
|
332 |
+
the User Product in which it has been modified or installed. Access to a
|
333 |
+
network may be denied when the modification itself materially and
|
334 |
+
adversely affects the operation of the network or violates the rules and
|
335 |
+
protocols for communication across the network.
|
336 |
+
|
337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
338 |
+
in accord with this section must be in a format that is publicly
|
339 |
+
documented (and with an implementation available to the public in
|
340 |
+
source code form), and must require no special password or key for
|
341 |
+
unpacking, reading or copying.
|
342 |
+
|
343 |
+
7. Additional Terms.
|
344 |
+
|
345 |
+
"Additional permissions" are terms that supplement the terms of this
|
346 |
+
License by making exceptions from one or more of its conditions.
|
347 |
+
Additional permissions that are applicable to the entire Program shall
|
348 |
+
be treated as though they were included in this License, to the extent
|
349 |
+
that they are valid under applicable law. If additional permissions
|
350 |
+
apply only to part of the Program, that part may be used separately
|
351 |
+
under those permissions, but the entire Program remains governed by
|
352 |
+
this License without regard to the additional permissions.
|
353 |
+
|
354 |
+
When you convey a copy of a covered work, you may at your option
|
355 |
+
remove any additional permissions from that copy, or from any part of
|
356 |
+
it. (Additional permissions may be written to require their own
|
357 |
+
removal in certain cases when you modify the work.) You may place
|
358 |
+
additional permissions on material, added by you to a covered work,
|
359 |
+
for which you have or can give appropriate copyright permission.
|
360 |
+
|
361 |
+
Notwithstanding any other provision of this License, for material you
|
362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
363 |
+
that material) supplement the terms of this License with terms:
|
364 |
+
|
365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
366 |
+
terms of sections 15 and 16 of this License; or
|
367 |
+
|
368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
369 |
+
author attributions in that material or in the Appropriate Legal
|
370 |
+
Notices displayed by works containing it; or
|
371 |
+
|
372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
373 |
+
requiring that modified versions of such material be marked in
|
374 |
+
reasonable ways as different from the original version; or
|
375 |
+
|
376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
377 |
+
authors of the material; or
|
378 |
+
|
379 |
+
e) Declining to grant rights under trademark law for use of some
|
380 |
+
trade names, trademarks, or service marks; or
|
381 |
+
|
382 |
+
f) Requiring indemnification of licensors and authors of that
|
383 |
+
material by anyone who conveys the material (or modified versions of
|
384 |
+
it) with contractual assumptions of liability to the recipient, for
|
385 |
+
any liability that these contractual assumptions directly impose on
|
386 |
+
those licensors and authors.
|
387 |
+
|
388 |
+
All other non-permissive additional terms are considered "further
|
389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
390 |
+
received it, or any part of it, contains a notice stating that it is
|
391 |
+
governed by this License along with a term that is a further
|
392 |
+
restriction, you may remove that term. If a license document contains
|
393 |
+
a further restriction but permits relicensing or conveying under this
|
394 |
+
License, you may add to a covered work material governed by the terms
|
395 |
+
of that license document, provided that the further restriction does
|
396 |
+
not survive such relicensing or conveying.
|
397 |
+
|
398 |
+
If you add terms to a covered work in accord with this section, you
|
399 |
+
must place, in the relevant source files, a statement of the
|
400 |
+
additional terms that apply to those files, or a notice indicating
|
401 |
+
where to find the applicable terms.
|
402 |
+
|
403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
404 |
+
form of a separately written license, or stated as exceptions;
|
405 |
+
the above requirements apply either way.
|
406 |
+
|
407 |
+
8. Termination.
|
408 |
+
|
409 |
+
You may not propagate or modify a covered work except as expressly
|
410 |
+
provided under this License. Any attempt otherwise to propagate or
|
411 |
+
modify it is void, and will automatically terminate your rights under
|
412 |
+
this License (including any patent licenses granted under the third
|
413 |
+
paragraph of section 11).
|
414 |
+
|
415 |
+
However, if you cease all violation of this License, then your
|
416 |
+
license from a particular copyright holder is reinstated (a)
|
417 |
+
provisionally, unless and until the copyright holder explicitly and
|
418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
419 |
+
holder fails to notify you of the violation by some reasonable means
|
420 |
+
prior to 60 days after the cessation.
|
421 |
+
|
422 |
+
Moreover, your license from a particular copyright holder is
|
423 |
+
reinstated permanently if the copyright holder notifies you of the
|
424 |
+
violation by some reasonable means, this is the first time you have
|
425 |
+
received notice of violation of this License (for any work) from that
|
426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
427 |
+
your receipt of the notice.
|
428 |
+
|
429 |
+
Termination of your rights under this section does not terminate the
|
430 |
+
licenses of parties who have received copies or rights from you under
|
431 |
+
this License. If your rights have been terminated and not permanently
|
432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
433 |
+
material under section 10.
|
434 |
+
|
435 |
+
9. Acceptance Not Required for Having Copies.
|
436 |
+
|
437 |
+
You are not required to accept this License in order to receive or
|
438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
440 |
+
to receive a copy likewise does not require acceptance. However,
|
441 |
+
nothing other than this License grants you permission to propagate or
|
442 |
+
modify any covered work. These actions infringe copyright if you do
|
443 |
+
not accept this License. Therefore, by modifying or propagating a
|
444 |
+
covered work, you indicate your acceptance of this License to do so.
|
445 |
+
|
446 |
+
10. Automatic Licensing of Downstream Recipients.
|
447 |
+
|
448 |
+
Each time you convey a covered work, the recipient automatically
|
449 |
+
receives a license from the original licensors, to run, modify and
|
450 |
+
propagate that work, subject to this License. You are not responsible
|
451 |
+
for enforcing compliance by third parties with this License.
|
452 |
+
|
453 |
+
An "entity transaction" is a transaction transferring control of an
|
454 |
+
organization, or substantially all assets of one, or subdividing an
|
455 |
+
organization, or merging organizations. If propagation of a covered
|
456 |
+
work results from an entity transaction, each party to that
|
457 |
+
transaction who receives a copy of the work also receives whatever
|
458 |
+
licenses to the work the party's predecessor in interest had or could
|
459 |
+
give under the previous paragraph, plus a right to possession of the
|
460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
461 |
+
the predecessor has it or can get it with reasonable efforts.
|
462 |
+
|
463 |
+
You may not impose any further restrictions on the exercise of the
|
464 |
+
rights granted or affirmed under this License. For example, you may
|
465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
466 |
+
rights granted under this License, and you may not initiate litigation
|
467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
468 |
+
any patent claim is infringed by making, using, selling, offering for
|
469 |
+
sale, or importing the Program or any portion of it.
|
470 |
+
|
471 |
+
11. Patents.
|
472 |
+
|
473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
474 |
+
License of the Program or a work on which the Program is based. The
|
475 |
+
work thus licensed is called the contributor's "contributor version".
|
476 |
+
|
477 |
+
A contributor's "essential patent claims" are all patent claims
|
478 |
+
owned or controlled by the contributor, whether already acquired or
|
479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
480 |
+
by this License, of making, using, or selling its contributor version,
|
481 |
+
but do not include claims that would be infringed only as a
|
482 |
+
consequence of further modification of the contributor version. For
|
483 |
+
purposes of this definition, "control" includes the right to grant
|
484 |
+
patent sublicenses in a manner consistent with the requirements of
|
485 |
+
this License.
|
486 |
+
|
487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
488 |
+
patent license under the contributor's essential patent claims, to
|
489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
490 |
+
propagate the contents of its contributor version.
|
491 |
+
|
492 |
+
In the following three paragraphs, a "patent license" is any express
|
493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
494 |
+
(such as an express permission to practice a patent or covenant not to
|
495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
496 |
+
party means to make such an agreement or commitment not to enforce a
|
497 |
+
patent against the party.
|
498 |
+
|
499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
500 |
+
and the Corresponding Source of the work is not available for anyone
|
501 |
+
to copy, free of charge and under the terms of this License, through a
|
502 |
+
publicly available network server or other readily accessible means,
|
503 |
+
then you must either (1) cause the Corresponding Source to be so
|
504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
506 |
+
consistent with the requirements of this License, to extend the patent
|
507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
508 |
+
actual knowledge that, but for the patent license, your conveying the
|
509 |
+
covered work in a country, or your recipient's use of the covered work
|
510 |
+
in a country, would infringe one or more identifiable patents in that
|
511 |
+
country that you have reason to believe are valid.
|
512 |
+
|
513 |
+
If, pursuant to or in connection with a single transaction or
|
514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
515 |
+
covered work, and grant a patent license to some of the parties
|
516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
517 |
+
or convey a specific copy of the covered work, then the patent license
|
518 |
+
you grant is automatically extended to all recipients of the covered
|
519 |
+
work and works based on it.
|
520 |
+
|
521 |
+
A patent license is "discriminatory" if it does not include within
|
522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
524 |
+
specifically granted under this License. You may not convey a covered
|
525 |
+
work if you are a party to an arrangement with a third party that is
|
526 |
+
in the business of distributing software, under which you make payment
|
527 |
+
to the third party based on the extent of your activity of conveying
|
528 |
+
the work, and under which the third party grants, to any of the
|
529 |
+
parties who would receive the covered work from you, a discriminatory
|
530 |
+
patent license (a) in connection with copies of the covered work
|
531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
532 |
+
for and in connection with specific products or compilations that
|
533 |
+
contain the covered work, unless you entered into that arrangement,
|
534 |
+
or that patent license was granted, prior to 28 March 2007.
|
535 |
+
|
536 |
+
Nothing in this License shall be construed as excluding or limiting
|
537 |
+
any implied license or other defenses to infringement that may
|
538 |
+
otherwise be available to you under applicable patent law.
|
539 |
+
|
540 |
+
12. No Surrender of Others' Freedom.
|
541 |
+
|
542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
543 |
+
otherwise) that contradict the conditions of this License, they do not
|
544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
546 |
+
License and any other pertinent obligations, then as a consequence you may
|
547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
548 |
+
to collect a royalty for further conveying from those to whom you convey
|
549 |
+
the Program, the only way you could satisfy both those terms and this
|
550 |
+
License would be to refrain entirely from conveying the Program.
|
551 |
+
|
552 |
+
13. Use with the GNU Affero General Public License.
|
553 |
+
|
554 |
+
Notwithstanding any other provision of this License, you have
|
555 |
+
permission to link or combine any covered work with a work licensed
|
556 |
+
under version 3 of the GNU Affero General Public License into a single
|
557 |
+
combined work, and to convey the resulting work. The terms of this
|
558 |
+
License will continue to apply to the part which is the covered work,
|
559 |
+
but the special requirements of the GNU Affero General Public License,
|
560 |
+
section 13, concerning interaction through a network will apply to the
|
561 |
+
combination as such.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU General Public License from time to time. Such new versions will
|
567 |
+
be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
635 |
+
Copyright (C) <year> <name of author>
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU General Public License
|
648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If the program does terminal interaction, make it output a short
|
653 |
+
notice like this when it starts in an interactive mode:
|
654 |
+
|
655 |
+
<program> Copyright (C) <year> <name of author>
|
656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
657 |
+
This is free software, and you are welcome to redistribute it
|
658 |
+
under certain conditions; type `show c' for details.
|
659 |
+
|
660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
661 |
+
parts of the General Public License. Of course, your program's commands
|
662 |
+
might be different; for a GUI interface, you would use an "about box".
|
663 |
+
|
664 |
+
You should also get your employer (if you work as a programmer) or school,
|
665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
667 |
+
<https://www.gnu.org/licenses/>.
|
668 |
+
|
669 |
+
The GNU General Public License does not permit incorporating your program
|
670 |
+
into proprietary programs. If your program is a subroutine library, you
|
671 |
+
may consider it more useful to permit linking proprietary applications with
|
672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
673 |
+
Public License instead of this License. But first, please read
|
674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
README.md
CHANGED
@@ -1,3 +1,174 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
|
2 |
+
|
3 |
+
This repository contains the code used to train the model presented in our **[paper](https://arxiv.org/abs/2110.02871)**.
|
4 |
+
|
5 |
+
It is not simply a presentation repository but the code we have used over the past 30 months to come to our final architecture. As such, you will find many scripts, classes, blocks and options which we actively use for our own development purposes but are not directly relevant to reproduce results or use pretrained weights.
|
6 |
+
|
7 |
+
![flood processing](images/flood.png)
|
8 |
+
|
9 |
+
If you use this code, data or pre-trained weights, please cite our ICLR 2022 paper:
|
10 |
+
|
11 |
+
```
|
12 |
+
@inproceedings{schmidt2022climategan,
|
13 |
+
title = {Climate{GAN}: Raising Climate Change Awareness by Generating Images of Floods},
|
14 |
+
author = {Victor Schmidt and Alexandra Luccioni and M{\'e}lisande Teng and Tianyu Zhang and Alexia Reynaud and Sunand Raghupathi and Gautier Cosne and Adrien Juraver and Vahe Vardanyan and Alex Hern{\'a}ndez-Garc{\'\i}a and Yoshua Bengio},
|
15 |
+
booktitle = {International Conference on Learning Representations},
|
16 |
+
year = {2022},
|
17 |
+
url = {https://openreview.net/forum?id=EZNOb_uNpJk}
|
18 |
+
}
|
19 |
+
```
|
20 |
+
|
21 |
+
## Using pre-trained weights
|
22 |
+
|
23 |
+
In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
|
24 |
+
|
25 |
+
* reusing the segmentation map, we are able to isolate the sky, turn it red and in a few more steps create an image resembling the consequences of a wildfire on a neighboring area, similarly to the [California wildfires](https://www.google.com/search?q=california+wildfires+red+sky&source=lnms&tbm=isch&sa=X&ved=2ahUKEwisws-hx7zxAhXxyYUKHQyKBUwQ_AUoAXoECAEQBA&biw=1680&bih=917&dpr=2).
|
26 |
+
* reusing the depth map, we can simulate the consequences of a smog event on an image, scaling the intensity of the filter by the distance of an object to the camera, as per [HazeRD](http://www2.ece.rochester.edu/~gsharma/papers/Zhang_ICIP2017_HazeRD.pdf)
|
27 |
+
|
28 |
+
![image of wildfire processing](images/wildfire.png)
|
29 |
+
![image of smog processing](images/smog.png)
|
30 |
+
|
31 |
+
In this section we'll explain how to produce the `Painted Input` along with the Smog and Wildfire outputs of a pre-trained ClimateGAN model.
|
32 |
+
|
33 |
+
### Installation
|
34 |
+
|
35 |
+
This repository and associated model have been developed using Python 3.8.2 and **Pytorch 1.7.0**.
|
36 |
+
|
37 |
+
```bash
|
38 |
+
$ git clone [email protected]:cc-ai/climategan.git
|
39 |
+
$ cd climategan
|
40 |
+
$ pip install -r requirements-3.8.2.txt # or `requirements-any.txt` for other Python versions (not tested but expected to be fine)
|
41 |
+
```
|
42 |
+
|
43 |
+
Our pipeline uses [comet.ml](https://comet.ml) to log images. You don't *have* to use their services but we recommend you do as images can be uploaded on your workspace instead of being written to disk.
|
44 |
+
|
45 |
+
If you want to use Comet, make sure you have the [appropriate configuration in place (API key and workspace at least)](https://www.comet.ml/docs/python-sdk/advanced/#non-interactive-setup)
|
46 |
+
|
47 |
+
### Inference
|
48 |
+
|
49 |
+
1. Download and unzip the weights [from this link](https://drive.google.com/u/0/uc?id=18OCUIy7JQ2Ow_-cC5xn_hhDn-Bp45N1K&export=download) (checkout [`gdown`](https://github.com/wkentaro/gdown) for a commandline interface) and put them in `config/`
|
50 |
+
|
51 |
+
```
|
52 |
+
$ pip install gdown
|
53 |
+
$ mkdir config
|
54 |
+
$ cd config
|
55 |
+
$ gdown https://drive.google.com/u/0/uc?id=18OCUIy7JQ2Ow_-cC5xn_hhDn-Bp45N1K
|
56 |
+
$ unzip release-github-v1.zip
|
57 |
+
$ cd ..
|
58 |
+
```
|
59 |
+
|
60 |
+
2. Run from the repo's root:
|
61 |
+
|
62 |
+
1. With `comet`:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
python apply_events.py --batch_size 4 --half --images_paths path/to/a/folder --resume_path config/model/masker --upload
|
66 |
+
```
|
67 |
+
|
68 |
+
2. Without `comet` (and shortened args compared to the previous example):
|
69 |
+
|
70 |
+
```bash
|
71 |
+
python apply_events.py -b 4 --half -i path/to/a/folder -r config/model/masker --output_path path/to/a/folder
|
72 |
+
```
|
73 |
+
|
74 |
+
The `apply_events.py` script has many options, for instance to use a different output size than the default systematic `640 x 640` pixels, look at the code or `python apply_events.py --help`.
|
75 |
+
|
76 |
+
## Training from scratch
|
77 |
+
|
78 |
+
ClimateGAN is split in two main components: the Masker producing a binary mask of where water should go and the Painter generating water within this mask given an initial image's context.
|
79 |
+
|
80 |
+
### Configuration
|
81 |
+
|
82 |
+
The code is structured to use `shared/trainer/defaults.yaml` as default configuration. There are 2 ways of overriding those for your purposes (without altering that file):
|
83 |
+
|
84 |
+
1. By providing an alternative configuration as command line argument `config=path/to/config.yaml`
|
85 |
+
|
86 |
+
1. The code will first load `shared/trainer/defaults.yaml`
|
87 |
+
2. *then* update the resulting dictionary with values read in the provided `config` argument.
|
88 |
+
3. The folder `config/` is NOT tracked by git so you would typically put them there
|
89 |
+
|
90 |
+
2. By overwriting specific arguments from the command-line like `python train.py data.loaders.batch_size=8`
|
91 |
+
|
92 |
+
|
93 |
+
### Data
|
94 |
+
|
95 |
+
#### Masker
|
96 |
+
|
97 |
+
##### Real Images
|
98 |
+
|
99 |
+
Because of copyrights issues we are not able to share the real images scrapped from the internet. You would have to do that yourself. In the `yaml` config file, the code expects a key pointing to a `json` file like `data.files.<train or val>.r: <path/to/a/json/file>`. This `json` file should be a list of dictionaries with tasks as keys and files as values. Example:
|
100 |
+
|
101 |
+
```json
|
102 |
+
[
|
103 |
+
{
|
104 |
+
"x": "path/to/a/real/image",
|
105 |
+
"s": "path/to/a/segmentation_map",
|
106 |
+
"d": "path/to/a/depth_map"
|
107 |
+
},
|
108 |
+
...
|
109 |
+
]
|
110 |
+
```
|
111 |
+
|
112 |
+
Following the [ADVENT](https://github.com/valeoai/ADVENT) procedure, only `x` should be required. We use `s` and `d` inferred from pre-trained models (DeepLab v3+ and MiDAS) to use those pseudo-labels in the first epochs of training (see `pseudo:` in the config file)
|
113 |
+
|
114 |
+
##### Simulated Images
|
115 |
+
|
116 |
+
We share snapshots of the Virtual World we created in the [Mila-Simulated-Flood dataset](). You can download and unzip one water-level and then produce json files similar to that of the real data, with an additional key `"m": "path/to/a/ground_truth_sim_mask"`. Lastly, edit the config file: `data.files.<train or val>.s: <path/to/a/json/file>`
|
117 |
+
|
118 |
+
#### Painter
|
119 |
+
|
120 |
+
The painter expects input images and binary masks to train using the [GauGAN](https://github.com/NVlabs/SPADE) training procedure. Unfortunately we cannot share openly the collected data, but similarly as for the Masker's real data you would point to the data using a `json` file as:
|
121 |
+
|
122 |
+
```json
|
123 |
+
[
|
124 |
+
{
|
125 |
+
"x": "path/to/a/real/image",
|
126 |
+
"m": "path/to/a/water_mask",
|
127 |
+
},
|
128 |
+
...
|
129 |
+
]
|
130 |
+
```
|
131 |
+
|
132 |
+
And put those files as values to `data.files.<train or val>.rf: <path/to/a/json/file>` in the configuration.
|
133 |
+
|
134 |
+
## Coding conventions
|
135 |
+
|
136 |
+
* Tasks
|
137 |
+
* `x` is an input image, in [-1, 1]
|
138 |
+
* `s` is a segmentation target with `long` classes
|
139 |
+
* `d` is a depth map target in R, may be actually `log(depth)` or `1/depth`
|
140 |
+
* `m` is a binary mask with 1s where water is/should be
|
141 |
+
* Domains
|
142 |
+
* `r` is the *real* domain for the masker. Input images are real pictures of urban/suburban/rural areas
|
143 |
+
* `s` is the *simulated* domain for the masker. Input images are taken from our Unity world
|
144 |
+
* `rf` is the *real flooded* domain for the painter. Training images are pairs `(x, m)` of flooded scenes for which the water should be reconstructed, in the validation data input images are not flooded and we provide a manually labeled mask `m`
|
145 |
+
* `kitti` is a special `s` domain to pre-train the masker on [Virtual Kitti 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
|
146 |
+
* it alters the `trainer.loaders` dict to select relevant data sources from `trainer.all_loaders` in `trainer.switch_data()`. The rest of the code is identical.
|
147 |
+
* Flow
|
148 |
+
* This describes the call stack for the trainers standard training procedure
|
149 |
+
* `train()`
|
150 |
+
* `run_epoch()`
|
151 |
+
* `update_G()`
|
152 |
+
* `zero_grad(G)`
|
153 |
+
* `get_G_loss()`
|
154 |
+
* `get_masker_loss()`
|
155 |
+
* `masker_m_loss()` -> masking loss
|
156 |
+
* `masker_s_loss()` -> segmentation loss
|
157 |
+
* `masker_d_loss()` -> depth estimation loss
|
158 |
+
* `get_painter_loss()` -> painter's loss
|
159 |
+
* `g_loss.backward()`
|
160 |
+
* `g_opt_step()`
|
161 |
+
* `update_D()`
|
162 |
+
* `zero_grad(D)`
|
163 |
+
* `get_D_loss()`
|
164 |
+
* painter's disc losses
|
165 |
+
* `masker_m_loss()` -> masking AdvEnt disc loss
|
166 |
+
* `masker_s_loss()` -> segmentation AdvEnt disc loss
|
167 |
+
* `d_loss.backward()`
|
168 |
+
* `d_opt_step()`
|
169 |
+
* `update_learning_rates()` -> update learning rates according to schedules defined in `opts.gen.opt` and `opts.dis.opt`
|
170 |
+
* `run_validation()`
|
171 |
+
* compute val losses
|
172 |
+
* `eval_images()` -> compute metrics
|
173 |
+
* `log_comet_images()` -> compute and upload inferences
|
174 |
+
* `save()`
|
USAGE.md
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ClimateGAN
|
2 |
+
- [ClimateGAN](#climategan)
|
3 |
+
- [Setup](#setup)
|
4 |
+
- [Coding conventions](#coding-conventions)
|
5 |
+
- [updates](#updates)
|
6 |
+
- [interfaces](#interfaces)
|
7 |
+
- [Logging on comet](#logging-on-comet)
|
8 |
+
- [Resources](#resources)
|
9 |
+
- [Example](#example)
|
10 |
+
- [Release process](#release-process)
|
11 |
+
|
12 |
+
## Setup
|
13 |
+
|
14 |
+
**`PyTorch >= 1.1.0`** otherwise optimizer.step() and scheduler.step() are in the wrong order ([docs](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate))
|
15 |
+
|
16 |
+
**pytorch==1.6** to use pytorch-xla or automatic mixed precision (`amp` branch).
|
17 |
+
|
18 |
+
Configuration files use the **YAML** syntax. If you don't know what `&` and `<<` mean, you'll have a hard time reading the files. Have a look at:
|
19 |
+
|
20 |
+
* https://dev.to/paulasantamaria/introduction-to-yaml-125f
|
21 |
+
* https://stackoverflow.com/questions/41063361/what-is-the-double-left-arrow-syntax-in-yaml-called-and-wheres-it-specced/41065222
|
22 |
+
|
23 |
+
**pip**
|
24 |
+
|
25 |
+
```
|
26 |
+
$ pip install comet_ml scipy opencv-python torch torchvision omegaconf==1.4.1 hydra-core==0.11.3 scikit-image imageio addict tqdm torch_optimizer
|
27 |
+
```
|
28 |
+
|
29 |
+
## Coding conventions
|
30 |
+
|
31 |
+
* Tasks
|
32 |
+
* `x` is an input image, in [-1, 1]
|
33 |
+
* `s` is a segmentation target with `long` classes
|
34 |
+
* `d` is a depth map target in R, may be actually `log(depth)` or `1/depth`
|
35 |
+
* `m` is a binary mask with 1s where water is/should be
|
36 |
+
* Domains
|
37 |
+
* `r` is the *real* domain for the masker. Input images are real pictures of urban/suburban/rural areas
|
38 |
+
* `s` is the *simulated* domain for the masker. Input images are taken from our Unity world
|
39 |
+
* `rf` is the *real flooded* domain for the painter. Training images are pairs `(x, m)` of flooded scenes for which the water should be reconstructed, in the validation data input images are not flooded and we provide a manually labeled mask `m`
|
40 |
+
* `kitti` is a special `s` domain to pre-train the masker on [Virtual Kitti 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
|
41 |
+
* it alters the `trainer.loaders` dict to select relevant data sources from `trainer.all_loaders` in `trainer.switch_data()`. The rest of the code is identical.
|
42 |
+
* Flow
|
43 |
+
* This describes the call stack for the trainers standard training procedure
|
44 |
+
* `train()`
|
45 |
+
* `run_epoch()`
|
46 |
+
* `update_G()`
|
47 |
+
* `zero_grad(G)`
|
48 |
+
* `get_G_loss()`
|
49 |
+
* `get_masker_loss()`
|
50 |
+
* `masker_m_loss()` -> masking loss
|
51 |
+
* `masker_s_loss()` -> segmentation loss
|
52 |
+
* `masker_d_loss()` -> depth estimation loss
|
53 |
+
* `get_painter_loss()` -> painter's loss
|
54 |
+
* `g_loss.backward()`
|
55 |
+
* `g_opt_step()`
|
56 |
+
* `update_D()`
|
57 |
+
* `zero_grad(D)`
|
58 |
+
* `get_D_loss()`
|
59 |
+
* painter's disc losses
|
60 |
+
* `masker_m_loss()` -> masking AdvEnt disc loss
|
61 |
+
* `masker_s_loss()` -> segmentation AdvEnt disc loss
|
62 |
+
* `d_loss.backward()`
|
63 |
+
* `d_opt_step()`
|
64 |
+
* `update_learning_rates()` -> update learning rates according to schedules defined in `opts.gen.opt` and `opts.dis.opt`
|
65 |
+
* `run_validation()`
|
66 |
+
* compute val losses
|
67 |
+
* `eval_images()` -> compute metrics
|
68 |
+
* `log_comet_images()` -> compute and upload inferences
|
69 |
+
* `save()`
|
70 |
+
|
71 |
+
### Resuming
|
72 |
+
|
73 |
+
Set `train.resume` to `True` in `opts.yaml` and specify where to load the weights:
|
74 |
+
|
75 |
+
Use a config's `load_path` namespace. It should have sub-keys `m`, `p` and `pm`:
|
76 |
+
|
77 |
+
```yaml
|
78 |
+
load_paths:
|
79 |
+
p: none # Painter weights
|
80 |
+
m: none # Masker weights
|
81 |
+
pm: none # Painter + Masker weights (single ckpt for both)
|
82 |
+
```
|
83 |
+
|
84 |
+
1. any path which leads to a dir will be loaded as `path / checkpoints / latest_ckpt.pth`
|
85 |
+
2. if you want to specify a specific checkpoint (not the latest), it MUST be a `.pth` file
|
86 |
+
3. resuming a `P` **OR** an `M` model, you may only specify 1 of `load_path.p` **OR** `load_path.m`.
|
87 |
+
You may also leave **BOTH** at `none`, in which case `output_path / checkpoints / latest_ckpt.pth`
|
88 |
+
will be used
|
89 |
+
4. resuming a P+M model, you may specify (`p` AND `m`) **OR** `pm` **OR** leave all at `none`,
|
90 |
+
in which case `output_path / checkpoints / latest_ckpt.pth` will be used to load from
|
91 |
+
a single checkpoint
|
92 |
+
|
93 |
+
### Generator
|
94 |
+
|
95 |
+
* **Encoder**:
|
96 |
+
|
97 |
+
`trainer.G.encoder` Deeplabv2 or v3-based encoder
|
98 |
+
* Code borrowed from
|
99 |
+
* https://github.com/valeoai/ADVENT/blob/master/advent/model/deeplabv2.py
|
100 |
+
* https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes
|
101 |
+
|
102 |
+
* **Decoders**:
|
103 |
+
* `trainer.G.decoders["s"]` -> *Segmentation* -> DLV3+ architecture (ASPP + Decoder)
|
104 |
+
* `trainer.G.decoders["d"]` -> *Depth* -> ResBlocks + (Upsample + Conv)
|
105 |
+
* `trainer.G.decoders["m"]` -> *Mask* -> ResBlocks + (Upsample + Conv) -> Binary mask: 1 = water should be there
|
106 |
+
* `trainer.G.mask()` predicts a mask and optionally applies `sigmoid` from an `x` input or a `z` input
|
107 |
+
|
108 |
+
* **Painter**: `trainer.G.painter` -> [GauGAN SPADE-based](https://github.com/NVlabs/SPADE)
|
109 |
+
* input = masked image
|
110 |
+
* `trainer.G.paint(m, x)` higher level function which takes care of masking
|
111 |
+
* If `opts.gen.p.paste_original_content` the painter should only create water and not reconstruct outside the mask: the output of `paint()` is `painted * m + x * (1 - m)`
|
112 |
+
|
113 |
+
High level methods of interest:
|
114 |
+
|
115 |
+
* `trainer.infer_all()` creates a dictionary of events with keys `flood` `wildfire` and `smog`. Can take in a single image or a batch, of numpy arrays or torch tensors, on CPU/GPU/TPU. This method calls, amongst others:
|
116 |
+
* `trainer.G.encode()` to compute the shared latent vector `z`
|
117 |
+
* `trainer.G.mask(z=z)` to infer the mask
|
118 |
+
* `trainer.compute_fire(x, segmentation)` to create a wildfire image from `x` and inferred segmentation
|
119 |
+
* `trainer.compute_smog(x, depth)` to create a smog image from `x` and inferred depth
|
120 |
+
* `trainer.compute_flood(x, mask)` to create a flood image from `x` and inferred mask using the painter (`trainer.G.paint(m, x)`)
|
121 |
+
* `Trainer.resume_from_path()` static method to resume a trainer from a path
|
122 |
+
|
123 |
+
### Discriminator
|
124 |
+
|
125 |
+
## updates
|
126 |
+
|
127 |
+
multi-batch:
|
128 |
+
|
129 |
+
```
|
130 |
+
multi_domain_batch = {"rf: batch0, "r": batch1, "s": batch2}
|
131 |
+
```
|
132 |
+
|
133 |
+
## interfaces
|
134 |
+
|
135 |
+
### batches
|
136 |
+
```python
|
137 |
+
batch = Dict({
|
138 |
+
"data": {
|
139 |
+
"d": depthmap,,
|
140 |
+
"s": segmentation_map,
|
141 |
+
"m": binary_mask
|
142 |
+
"x": real_flooded_image,
|
143 |
+
},
|
144 |
+
"paths":{
|
145 |
+
same_keys: path_to_file
|
146 |
+
}
|
147 |
+
"domain": list(rf | r | s),
|
148 |
+
"mode": list(train | val)
|
149 |
+
})
|
150 |
+
```
|
151 |
+
|
152 |
+
### data
|
153 |
+
|
154 |
+
#### json files
|
155 |
+
|
156 |
+
| name | domain | description | author |
|
157 |
+
| :--------------------------------------------- | :----: | :------------------------------------------------------------------------- | :-------: |
|
158 |
+
| **train_r_full.json, val_r_full.json** | r | MiDaS+ Segmentation pseudo-labels .pt (HRNet + Cityscapes) | Mélisande |
|
159 |
+
| **train_s_full.json, val_s_full.json** | s | Simulated data from Unity11k urban + Unity suburban dataset | *** |
|
160 |
+
| train_s_nofences.json, val_s_nofences.json | s | Simulated data from Unity11k urban + Unity suburban dataset without fences | Alexia |
|
161 |
+
| train_r_full_pl.json, val_r_full_pl.json | r | MegaDepth + Segmentation pseudo-labels .pt (HRNet + Cityscapes) | Alexia |
|
162 |
+
| train_r_full_midas.json, val_r_full_midas.json | r | MiDaS+ Segmentation (HRNet + Cityscapes) | Mélisande |
|
163 |
+
| train_r_full_old.json, val_r_full_old.json | r | MegaDepth+ Segmentation (HRNet + Cityscapes) | *** |
|
164 |
+
| train_r_nopeople.json, val_r_nopeople.json | r | Same training data as above with people removed | Sasha |
|
165 |
+
| train_rf_with_sim.json | rf | Doubled train_rf's size with sim data (randomly chosen) | Victor |
|
166 |
+
| train_rf.json | rf | UPDATE (12/12/20): added 50 ims & masks from ADE20K Outdoors | Victor |
|
167 |
+
| train_allres.json, val_allres.json | rf | includes both lowres and highres from ORCA_water_seg | Tianyu |
|
168 |
+
| train_highres_only.json, val_highres_only.json | rf | includes only highres from ORCA_water_seg | Tianyu |
|
169 |
+
|
170 |
+
|
171 |
+
```yaml
|
172 |
+
# data file ; one for each r|s
|
173 |
+
- x: /path/to/image
|
174 |
+
m: /path/to/mask
|
175 |
+
s: /path/to/segmentation map
|
176 |
+
- x: /path/to/another image
|
177 |
+
d: /path/to/depth map
|
178 |
+
m: /path/to/mask
|
179 |
+
s: /path/to/segmentation map
|
180 |
+
- x: ...
|
181 |
+
```
|
182 |
+
|
183 |
+
or
|
184 |
+
|
185 |
+
```json
|
186 |
+
[
|
187 |
+
{
|
188 |
+
"x": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005.jpg",
|
189 |
+
"s": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005.npy",
|
190 |
+
"d": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000005_depth.jpg"
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"x": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006.jpg",
|
194 |
+
"s": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006.npy",
|
195 |
+
"d": "/Users/victor/Documents/ccai/github/climategan/example_data/gsv_000006_depth.jpg"
|
196 |
+
}
|
197 |
+
]
|
198 |
+
```
|
199 |
+
|
200 |
+
The json files used are located at `/network/tmp1/ccai/data/climategan/`. In the basenames, `_s` denotes simulated domain data and `_r` real domain data.
|
201 |
+
The `base` folder contains json files with paths to images (`"x"`key) and masks (taken as ground truth for the area that should be flooded, `"m"` key).
|
202 |
+
The `seg` folder contains json files and keys `"x"`, `"m"` and `"s"` (segmentation) for each image.
|
203 |
+
|
204 |
+
|
205 |
+
loaders
|
206 |
+
|
207 |
+
```
|
208 |
+
loaders = Dict({
|
209 |
+
train: { r: loader, s: loader},
|
210 |
+
val: { r: loader, s: loader}
|
211 |
+
})
|
212 |
+
```
|
213 |
+
|
214 |
+
### losses
|
215 |
+
|
216 |
+
`trainer.losses` is a dictionary mapping to loss functions to optimize for the 3 main parts of the architecture: generator `G`, discriminators `D`:
|
217 |
+
|
218 |
+
```python
|
219 |
+
trainer.losses = {
|
220 |
+
"G":{ # generator
|
221 |
+
"gan": { # gan loss from the discriminators
|
222 |
+
"a": GANLoss, # adaptation decoder
|
223 |
+
"t": GANLoss # translation decoder
|
224 |
+
},
|
225 |
+
"cycle": { # cycle-consistency loss
|
226 |
+
"a": l1 | l2,,
|
227 |
+
"t": l1 | l2,
|
228 |
+
},
|
229 |
+
"auto": { # auto-encoding loss a.k.a. reconstruction loss
|
230 |
+
"a": l1 | l2,
|
231 |
+
"t": l1 | l2
|
232 |
+
},
|
233 |
+
"tasks": { # specific losses for each auxillary task
|
234 |
+
"d": func, # depth estimation
|
235 |
+
"h": func, # height estimation
|
236 |
+
"s": cross_entropy_2d, # segmentation
|
237 |
+
"w": func, # water generation
|
238 |
+
},
|
239 |
+
"classifier": l1 | l2 | CE # loss from fooling the classifier
|
240 |
+
},
|
241 |
+
"D": GANLoss, # discriminator losses from the generator and true data
|
242 |
+
"C": l1 | l2 | CE # classifier should predict the right 1-h vector [rf, rn, sf, sn]
|
243 |
+
}
|
244 |
+
```
|
245 |
+
|
246 |
+
## Logging on comet
|
247 |
+
|
248 |
+
Comet.ml will look for api keys in the following order: argument to the `Experiment(api_key=...)` call, `COMET_API_KEY` environment variable, `.comet.config` file in the current working directory, `.comet.config` in the current user's home directory.
|
249 |
+
|
250 |
+
If your not managing several comet accounts at the same time, I recommend putting `.comet.config` in your home as such:
|
251 |
+
|
252 |
+
```
|
253 |
+
[comet]
|
254 |
+
api_key=<api_key>
|
255 |
+
workspace=vict0rsch
|
256 |
+
rest_api_key=<rest_api_key>
|
257 |
+
```
|
258 |
+
|
259 |
+
### Tests
|
260 |
+
|
261 |
+
Run tests by executing `python test_trainer.py`. You can add `--no_delete` not to delete the comet experiment at exit and inspect uploads.
|
262 |
+
|
263 |
+
Write tests as scenarios by adding to the list `test_scenarios` in the file. A scenario is a dict of overrides over the base opts in `shared/trainer/defaults.yaml`. You can create special flags for the scenario by adding keys which start with `__`. For instance, `__doc` is a mandatory key in any scenario describing it succinctly.
|
264 |
+
|
265 |
+
## Resources
|
266 |
+
|
267 |
+
[Tricks and Tips for Training a GAN](https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/)
|
268 |
+
[GAN Hacks](https://github.com/soumith/ganhacks)
|
269 |
+
[Keep Calm and train a GAN. Pitfalls and Tips on training Generative Adversarial Networks](https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9)
|
270 |
+
|
271 |
+
## Example
|
272 |
+
|
273 |
+
**Inference: computing floods**
|
274 |
+
|
275 |
+
```python
|
276 |
+
from pathlib import Path
|
277 |
+
from skimage.io import imsave
|
278 |
+
from tqdm import tqdm
|
279 |
+
|
280 |
+
from climategan.trainer import Trainer
|
281 |
+
from climategan.utils import find_images
|
282 |
+
from climategan.tutils import tensor_ims_to_np_uint8s
|
283 |
+
from climategan.transforms import PrepareInference
|
284 |
+
|
285 |
+
|
286 |
+
model_path = "some/path/to/output/folder" # not .ckpt
|
287 |
+
input_folder = "path/to/a/folder/with/images"
|
288 |
+
output_path = "path/where/images/will/be/written"
|
289 |
+
|
290 |
+
# resume trainer
|
291 |
+
trainer = Trainer.resume_from_path(model_path, new_exp=None, inference=True)
|
292 |
+
|
293 |
+
# find paths for all images in the input folder. There is a recursive option.
|
294 |
+
im_paths = sorted(find_images(input_folder), key=lambda x: x.name)
|
295 |
+
|
296 |
+
# Load images into tensors
|
297 |
+
# * smaller side resized to 640 - keeping aspect ratio
|
298 |
+
# * then longer side is cropped in the center
|
299 |
+
# * result is a 1x3x640x640 float tensor in [-1; 1]
|
300 |
+
xs = PrepareInference()(im_paths)
|
301 |
+
|
302 |
+
# send to device
|
303 |
+
xs = [x.to(trainer.device) for x in xs]
|
304 |
+
|
305 |
+
# compute flood
|
306 |
+
# * compute mask
|
307 |
+
# * binarize mask if bin_value > 0
|
308 |
+
# * paint x using this mask
|
309 |
+
ys = [trainer.compute_flood(x, bin_value=0.5) for x in tqdm(xs)]
|
310 |
+
|
311 |
+
# convert 1x3x640x640 float tensors in [-1; 1] into 640x640x3 numpy arrays in [0, 255]
|
312 |
+
np_ys = [tensor_ims_to_np_uint8s(y) for y in tqdm(ys)]
|
313 |
+
|
314 |
+
# write images
|
315 |
+
for i, n in tqdm(zip(im_paths, np_ys), total=len(im_paths)):
|
316 |
+
imsave(Path(output_path) / i.name, n)
|
317 |
+
```
|
318 |
+
|
319 |
+
## Release process
|
320 |
+
|
321 |
+
In the `release/` folder
|
322 |
+
* create a `model/` folder
|
323 |
+
* create folders `model/masker/` and `model/painter/`
|
324 |
+
* add the climategan code in `release/`: `git clone [email protected]:cc-ai/climategan.git`
|
325 |
+
* move the code to `release/`: `cp climategan/* . && rm -rf climategan`
|
326 |
+
* update `model/masker/opts/events` with `events:` from `shared/trainer/opts.yaml`
|
327 |
+
* update `model/masker/opts/val.val_painter` to `"model/painter/checkpoints/latest_ckpt.pth"`
|
328 |
+
* update `model/masker/opts/load_paths.m` to `"model/masker/checkpoints/latest_ckpt.pth"`
|
apply_events.py
ADDED
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
def parse_args():
|
5 |
+
parser = argparse.ArgumentParser()
|
6 |
+
parser.add_argument(
|
7 |
+
"-b",
|
8 |
+
"--batch_size",
|
9 |
+
type=int,
|
10 |
+
default=4,
|
11 |
+
help="Batch size to process input images to events. Defaults to 4",
|
12 |
+
)
|
13 |
+
parser.add_argument(
|
14 |
+
"-i",
|
15 |
+
"--images_paths",
|
16 |
+
type=str,
|
17 |
+
required=True,
|
18 |
+
help="Path to a directory with image files",
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"-o",
|
22 |
+
"--output_path",
|
23 |
+
type=str,
|
24 |
+
default=None,
|
25 |
+
help="Path to a directory were events should be written. "
|
26 |
+
+ "Will NOT write anything to disk if this flag is not used.",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"-s",
|
30 |
+
"--save_input",
|
31 |
+
action="store_true",
|
32 |
+
default=False,
|
33 |
+
help="Binary flag to include the input image to the model (after crop and"
|
34 |
+
+ " resize) in the images written or uploaded (depending on saving options.)",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"-r",
|
38 |
+
"--resume_path",
|
39 |
+
type=str,
|
40 |
+
default=None,
|
41 |
+
help="Path to a directory containing the trainer to resume."
|
42 |
+
+ " In particular it must contain `opts.yam` and `checkpoints/`."
|
43 |
+
+ " Typically this points to a Masker, which holds the path to a"
|
44 |
+
+ " Painter in its opts",
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--no_time",
|
48 |
+
action="store_true",
|
49 |
+
default=False,
|
50 |
+
help="Binary flag to prevent the timing of operations.",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"-f",
|
54 |
+
"--flood_mask_binarization",
|
55 |
+
type=float,
|
56 |
+
default=0.5,
|
57 |
+
help="Value to use to binarize masks (mask > value). "
|
58 |
+
+ "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"-t",
|
62 |
+
"--target_size",
|
63 |
+
type=int,
|
64 |
+
default=640,
|
65 |
+
help="Output image size (when not using `keep_ratio_128`): images are resized"
|
66 |
+
+ " such that their smallest side is `target_size` then cropped in the middle"
|
67 |
+
+ " of the largest side such that the resulting input image (and output images)"
|
68 |
+
+ " has height and width `target_size x target_size`. **Must** be a multiple of"
|
69 |
+
+ " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--half",
|
73 |
+
action="store_true",
|
74 |
+
default=False,
|
75 |
+
help="Binary flag to use half precision (float16). Defaults to False.",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"-n",
|
79 |
+
"--n_images",
|
80 |
+
default=-1,
|
81 |
+
type=int,
|
82 |
+
help="Limit the number of images processed (if you have 100 images in "
|
83 |
+
+ "a directory but n is 10 then only the first 10 images will be loaded"
|
84 |
+
+ " for processing)",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--no_conf",
|
88 |
+
action="store_true",
|
89 |
+
default=False,
|
90 |
+
help="disable writing the apply_events hash and command in the output folder",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--overwrite",
|
94 |
+
action="store_true",
|
95 |
+
default=False,
|
96 |
+
help="Do not check for existing outdir, i.e. force overwrite"
|
97 |
+
+ " potentially existing files in the output path",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--no_cloudy",
|
101 |
+
action="store_true",
|
102 |
+
default=False,
|
103 |
+
help="Prevent the use of the cloudy intermediate"
|
104 |
+
+ " image to create the flood image. Rendering will"
|
105 |
+
+ " be more colorful but may seem less realistic",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--keep_ratio_128",
|
109 |
+
action="store_true",
|
110 |
+
default=False,
|
111 |
+
help="When loading the input images, resize and crop them in order for their "
|
112 |
+
+ "dimensions to match the closest multiples"
|
113 |
+
+ " of 128. Will force a batch size of 1 since images"
|
114 |
+
+ " now have different dimensions. "
|
115 |
+
+ "Use --max_im_width to cap the resulting dimensions.",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--fuse",
|
119 |
+
action="store_true",
|
120 |
+
default=False,
|
121 |
+
help="Use batch norm fusion to speed up inference",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--save_masks",
|
125 |
+
action="store_true",
|
126 |
+
default=False,
|
127 |
+
help="Save output masks along events",
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"-m",
|
131 |
+
"--max_im_width",
|
132 |
+
type=int,
|
133 |
+
default=-1,
|
134 |
+
help="When using --keep_ratio_128, some images may still be too large. Use "
|
135 |
+
+ "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--upload",
|
139 |
+
action="store_true",
|
140 |
+
help="Upload to comet.ml in a project called `climategan-apply`",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--zip_outdir",
|
144 |
+
"-z",
|
145 |
+
action="store_true",
|
146 |
+
help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
|
147 |
+
)
|
148 |
+
return parser.parse_args()
|
149 |
+
|
150 |
+
|
151 |
+
args = parse_args()
|
152 |
+
|
153 |
+
|
154 |
+
print("\n• Imports\n")
|
155 |
+
import time
|
156 |
+
|
157 |
+
import_time = time.time()
|
158 |
+
import sys
|
159 |
+
import shutil
|
160 |
+
from collections import OrderedDict
|
161 |
+
from pathlib import Path
|
162 |
+
|
163 |
+
import comet_ml # noqa: F401
|
164 |
+
import torch
|
165 |
+
import numpy as np
|
166 |
+
import skimage.io as io
|
167 |
+
from skimage.color import rgba2rgb
|
168 |
+
from skimage.transform import resize
|
169 |
+
from tqdm import tqdm
|
170 |
+
|
171 |
+
from climategan.trainer import Trainer
|
172 |
+
from climategan.bn_fusion import bn_fuse
|
173 |
+
from climategan.tutils import print_num_parameters
|
174 |
+
from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve
|
175 |
+
|
176 |
+
import_time = time.time() - import_time
|
177 |
+
|
178 |
+
|
179 |
+
def to_m1_p1(img, i):
|
180 |
+
"""
|
181 |
+
rescales a [0, 1] image to [-1, +1]
|
182 |
+
|
183 |
+
Args:
|
184 |
+
img (np.array): float32 numpy array of an image in [0, 1]
|
185 |
+
i (int): Index of the image being rescaled
|
186 |
+
|
187 |
+
Raises:
|
188 |
+
ValueError: If the image is not in [0, 1]
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
np.array(np.float32): array in [-1, +1]
|
192 |
+
"""
|
193 |
+
if img.min() >= 0 and img.max() <= 1:
|
194 |
+
return (img.astype(np.float32) - 0.5) * 2
|
195 |
+
raise ValueError(f"Data range mismatch for image {i} : ({img.min()}, {img.max()})")
|
196 |
+
|
197 |
+
|
198 |
+
def uint8(array):
|
199 |
+
"""
|
200 |
+
convert an array to np.uint8 (does not rescale or anything else than changing dtype)
|
201 |
+
|
202 |
+
Args:
|
203 |
+
array (np.array): array to modify
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
np.array(np.uint8): converted array
|
207 |
+
"""
|
208 |
+
return array.astype(np.uint8)
|
209 |
+
|
210 |
+
|
211 |
+
def resize_and_crop(img, to=640):
|
212 |
+
"""
|
213 |
+
Resizes an image so that it keeps the aspect ratio and the smallest dimensions
|
214 |
+
is `to`, then crops this resized image in its center so that the output is `to x to`
|
215 |
+
without aspect ratio distortion
|
216 |
+
|
217 |
+
Args:
|
218 |
+
img (np.array): np.uint8 255 image
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
np.array: [0, 1] np.float32 image
|
222 |
+
"""
|
223 |
+
# resize keeping aspect ratio: smallest dim is 640
|
224 |
+
h, w = img.shape[:2]
|
225 |
+
if h < w:
|
226 |
+
size = (to, int(to * w / h))
|
227 |
+
else:
|
228 |
+
size = (int(to * h / w), to)
|
229 |
+
|
230 |
+
r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
|
231 |
+
r_img = uint8(r_img)
|
232 |
+
|
233 |
+
# crop in the center
|
234 |
+
H, W = r_img.shape[:2]
|
235 |
+
|
236 |
+
top = (H - to) // 2
|
237 |
+
left = (W - to) // 2
|
238 |
+
|
239 |
+
rc_img = r_img[top : top + to, left : left + to, :]
|
240 |
+
|
241 |
+
return rc_img / 255.0
|
242 |
+
|
243 |
+
|
244 |
+
def print_time(text, time_series, purge=-1):
|
245 |
+
"""
|
246 |
+
Print a timeseries's mean and std with a label
|
247 |
+
|
248 |
+
Args:
|
249 |
+
text (str): label of the time series
|
250 |
+
time_series (list): list of timings
|
251 |
+
purge (int, optional): ignore first n values of time series. Defaults to -1.
|
252 |
+
"""
|
253 |
+
if not time_series:
|
254 |
+
return
|
255 |
+
|
256 |
+
if purge > 0 and len(time_series) > purge:
|
257 |
+
time_series = time_series[purge:]
|
258 |
+
|
259 |
+
m = np.mean(time_series)
|
260 |
+
s = np.std(time_series)
|
261 |
+
|
262 |
+
print(
|
263 |
+
f"{text.capitalize() + ' ':.<26} {m:.5f}"
|
264 |
+
+ (f" +/- {s:.5f}" if len(time_series) > 1 else "")
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
def print_store(store, purge=-1):
|
269 |
+
"""
|
270 |
+
Pretty-print time series store
|
271 |
+
|
272 |
+
Args:
|
273 |
+
store (dict): maps string keys to lists of times
|
274 |
+
purge (int, optional): ignore first n values of time series. Defaults to -1.
|
275 |
+
"""
|
276 |
+
singles = OrderedDict({k: v for k, v in store.items() if len(v) == 1})
|
277 |
+
multiples = OrderedDict({k: v for k, v in store.items() if len(v) > 1})
|
278 |
+
empties = {k: v for k, v in store.items() if len(v) == 0}
|
279 |
+
|
280 |
+
if empties:
|
281 |
+
print("Ignoring empty stores ", ", ".join(empties.keys()))
|
282 |
+
print()
|
283 |
+
|
284 |
+
for k in singles:
|
285 |
+
print_time(k, singles[k], purge)
|
286 |
+
|
287 |
+
print()
|
288 |
+
print("Unit: s/batch")
|
289 |
+
for k in multiples:
|
290 |
+
print_time(k, multiples[k], purge)
|
291 |
+
print()
|
292 |
+
|
293 |
+
|
294 |
+
def write_apply_config(out):
|
295 |
+
"""
|
296 |
+
Saves the args to `apply_events.py` in a text file for future reference
|
297 |
+
"""
|
298 |
+
cwd = Path.cwd().expanduser().resolve()
|
299 |
+
command = f"cd {str(cwd)}\n"
|
300 |
+
command += " ".join(sys.argv)
|
301 |
+
git_hash = get_git_revision_hash()
|
302 |
+
with (out / "command.txt").open("w") as f:
|
303 |
+
f.write(command)
|
304 |
+
with (out / "hash.txt").open("w") as f:
|
305 |
+
f.write(git_hash)
|
306 |
+
|
307 |
+
|
308 |
+
def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
|
309 |
+
"""
|
310 |
+
Create the output directory's name based on uer-provided arguments
|
311 |
+
"""
|
312 |
+
name_items = []
|
313 |
+
if half:
|
314 |
+
name_items.append("half")
|
315 |
+
if keep_ratio:
|
316 |
+
name_items.append("AR")
|
317 |
+
if max_im_width and keep_ratio:
|
318 |
+
name_items.append(f"{max_im_width}")
|
319 |
+
if target_size and not keep_ratio:
|
320 |
+
name_items.append("S")
|
321 |
+
name_items.append(f"{target_size}")
|
322 |
+
if bin_value != 0.5:
|
323 |
+
name_items.append(f"bin{bin_value}")
|
324 |
+
if not cloudy:
|
325 |
+
name_items.append("no_cloudy")
|
326 |
+
|
327 |
+
return "-".join(name_items)
|
328 |
+
|
329 |
+
|
330 |
+
def make_outdir(
|
331 |
+
outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
|
332 |
+
):
|
333 |
+
"""
|
334 |
+
Creates the output directory if it does not exist. If it does exist,
|
335 |
+
prompts the user for confirmation (except if `overwrite` is True).
|
336 |
+
If the output directory's name is "_auto_" then it is created as:
|
337 |
+
outdir.parent / get_outdir_name(...)
|
338 |
+
"""
|
339 |
+
if outdir.name == "_auto_":
|
340 |
+
outdir = outdir.parent / get_outdir_name(
|
341 |
+
half, keep_ratio, max_im_width, target_size, bin_value, cloudy
|
342 |
+
)
|
343 |
+
if outdir.exists() and not overwrite:
|
344 |
+
print(
|
345 |
+
f"\nWARNING: outdir ({str(outdir)}) already exists."
|
346 |
+
+ " Files with existing names will be overwritten"
|
347 |
+
)
|
348 |
+
if "n" in input(">>> Continue anyway? [y / n] (default: y) : "):
|
349 |
+
print("Interrupting execution from user input.")
|
350 |
+
sys.exit()
|
351 |
+
print()
|
352 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
353 |
+
return outdir
|
354 |
+
|
355 |
+
|
356 |
+
def get_time_stores(import_time):
|
357 |
+
return OrderedDict(
|
358 |
+
{
|
359 |
+
"imports": [import_time],
|
360 |
+
"setup": [],
|
361 |
+
"data pre-processing": [],
|
362 |
+
"encode": [],
|
363 |
+
"mask": [],
|
364 |
+
"flood": [],
|
365 |
+
"depth": [],
|
366 |
+
"segmentation": [],
|
367 |
+
"smog": [],
|
368 |
+
"wildfire": [],
|
369 |
+
"all events": [],
|
370 |
+
"numpy": [],
|
371 |
+
"inference on all images": [],
|
372 |
+
"write": [],
|
373 |
+
}
|
374 |
+
)
|
375 |
+
|
376 |
+
|
377 |
+
if __name__ == "__main__":
|
378 |
+
|
379 |
+
# -----------------------------------------
|
380 |
+
# ----- Initialize script variables -----
|
381 |
+
# -----------------------------------------
|
382 |
+
print(
|
383 |
+
"• Using args\n\n"
|
384 |
+
+ "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
|
385 |
+
)
|
386 |
+
|
387 |
+
batch_size = args.batch_size
|
388 |
+
bin_value = args.flood_mask_binarization
|
389 |
+
cloudy = not args.no_cloudy
|
390 |
+
fuse = args.fuse
|
391 |
+
half = args.half
|
392 |
+
save_masks = args.save_masks
|
393 |
+
images_paths = resolve(args.images_paths)
|
394 |
+
keep_ratio = args.keep_ratio_128
|
395 |
+
max_im_width = args.max_im_width
|
396 |
+
n_images = args.n_images
|
397 |
+
outdir = resolve(args.output_path) if args.output_path is not None else None
|
398 |
+
resume_path = args.resume_path
|
399 |
+
target_size = args.target_size
|
400 |
+
time_inference = not args.no_time
|
401 |
+
upload = args.upload
|
402 |
+
zip_outdir = args.zip_outdir
|
403 |
+
|
404 |
+
# -------------------------------------
|
405 |
+
# ----- Validate size arguments -----
|
406 |
+
# -------------------------------------
|
407 |
+
if keep_ratio:
|
408 |
+
if target_size != 640:
|
409 |
+
print(
|
410 |
+
"\nWARNING: using --keep_ratio_128 overwrites target_size"
|
411 |
+
+ " which is ignored."
|
412 |
+
)
|
413 |
+
if batch_size != 1:
|
414 |
+
print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
|
415 |
+
batch_size = 1
|
416 |
+
if max_im_width > 0 and max_im_width % 128 != 0:
|
417 |
+
new_im_width = int(max_im_width / 128) * 128
|
418 |
+
print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
|
419 |
+
print(
|
420 |
+
" Was {} but is now overwritten to {}".format(
|
421 |
+
max_im_width, new_im_width
|
422 |
+
)
|
423 |
+
)
|
424 |
+
max_im_width = new_im_width
|
425 |
+
else:
|
426 |
+
if target_size % 128 != 0:
|
427 |
+
print(f"\nWarning: target size {target_size} is not a multiple of 128.")
|
428 |
+
target_size = target_size - (target_size % 128)
|
429 |
+
print(f"Setting target_size to {target_size}.")
|
430 |
+
|
431 |
+
# -------------------------------------
|
432 |
+
# ----- Create output directory -----
|
433 |
+
# -------------------------------------
|
434 |
+
if outdir is not None:
|
435 |
+
outdir = make_outdir(
|
436 |
+
outdir,
|
437 |
+
args.overwrite,
|
438 |
+
half,
|
439 |
+
keep_ratio,
|
440 |
+
max_im_width,
|
441 |
+
target_size,
|
442 |
+
bin_value,
|
443 |
+
cloudy,
|
444 |
+
)
|
445 |
+
|
446 |
+
# -------------------------------
|
447 |
+
# ----- Create time store -----
|
448 |
+
# -------------------------------
|
449 |
+
stores = get_time_stores(import_time)
|
450 |
+
|
451 |
+
# -----------------------------------
|
452 |
+
# ----- Load Trainer instance -----
|
453 |
+
# -----------------------------------
|
454 |
+
with Timer(store=stores.get("setup", []), ignore=time_inference):
|
455 |
+
print("\n• Initializing trainer\n")
|
456 |
+
torch.set_grad_enabled(False)
|
457 |
+
trainer = Trainer.resume_from_path(
|
458 |
+
resume_path,
|
459 |
+
setup=True,
|
460 |
+
inference=True,
|
461 |
+
new_exp=None,
|
462 |
+
)
|
463 |
+
print()
|
464 |
+
print_num_parameters(trainer, True)
|
465 |
+
if fuse:
|
466 |
+
trainer.G = bn_fuse(trainer.G)
|
467 |
+
if half:
|
468 |
+
trainer.G.half()
|
469 |
+
|
470 |
+
# --------------------------------------------
|
471 |
+
# ----- Read data from input directory -----
|
472 |
+
# --------------------------------------------
|
473 |
+
print("\n• Reading & Pre-processing Data\n")
|
474 |
+
|
475 |
+
# find all images
|
476 |
+
data_paths = find_images(images_paths)
|
477 |
+
base_data_paths = data_paths
|
478 |
+
# filter images
|
479 |
+
if 0 < n_images < len(data_paths):
|
480 |
+
data_paths = data_paths[:n_images]
|
481 |
+
# repeat data
|
482 |
+
elif n_images > len(data_paths):
|
483 |
+
repeats = n_images // len(data_paths) + 1
|
484 |
+
data_paths = base_data_paths * repeats
|
485 |
+
data_paths = data_paths[:n_images]
|
486 |
+
|
487 |
+
with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
|
488 |
+
# read images to numpy arrays
|
489 |
+
data = [io.imread(str(d)) for d in data_paths]
|
490 |
+
# rgba to rgb
|
491 |
+
data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]
|
492 |
+
# resize images to target_size or
|
493 |
+
if keep_ratio:
|
494 |
+
# to closest multiples of 128 <= max_im_width, keeping aspect ratio
|
495 |
+
new_sizes = [to_128(d, max_im_width) for d in data]
|
496 |
+
data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
|
497 |
+
else:
|
498 |
+
# to args.target_size
|
499 |
+
data = [resize_and_crop(d, target_size) for d in data]
|
500 |
+
new_sizes = [(target_size, target_size) for _ in data]
|
501 |
+
# resize() produces [0, 1] images, rescale to [-1, 1]
|
502 |
+
data = [to_m1_p1(d, i) for i, d in enumerate(data)]
|
503 |
+
|
504 |
+
n_batchs = len(data) // batch_size
|
505 |
+
if len(data) % batch_size != 0:
|
506 |
+
n_batchs += 1
|
507 |
+
|
508 |
+
print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")
|
509 |
+
|
510 |
+
# --------------------------------------------
|
511 |
+
# ----- Batch-process images to events -----
|
512 |
+
# --------------------------------------------
|
513 |
+
print(f"\n• Using device {str(trainer.device)}\n")
|
514 |
+
|
515 |
+
all_events = []
|
516 |
+
|
517 |
+
with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
|
518 |
+
for b in tqdm(range(n_batchs), desc="Infering events", unit="batch"):
|
519 |
+
|
520 |
+
images = data[b * batch_size : (b + 1) * batch_size]
|
521 |
+
if not images:
|
522 |
+
continue
|
523 |
+
|
524 |
+
# concatenate images in a batch batch_size x height x width x 3
|
525 |
+
images = np.stack(images)
|
526 |
+
# Retreive numpy events as a dict {event: array[BxHxWxC]}
|
527 |
+
events = trainer.infer_all(
|
528 |
+
images,
|
529 |
+
numpy=True,
|
530 |
+
stores=stores,
|
531 |
+
bin_value=bin_value,
|
532 |
+
half=half,
|
533 |
+
cloudy=cloudy,
|
534 |
+
return_masks=save_masks,
|
535 |
+
)
|
536 |
+
|
537 |
+
# save resized and cropped image
|
538 |
+
if args.save_input:
|
539 |
+
events["input"] = uint8((images + 1) / 2 * 255)
|
540 |
+
|
541 |
+
# store events to write after inference loop
|
542 |
+
all_events.append(events)
|
543 |
+
|
544 |
+
# --------------------------------------------
|
545 |
+
# ----- Save (write/upload) inferences -----
|
546 |
+
# --------------------------------------------
|
547 |
+
if outdir is not None or upload:
|
548 |
+
|
549 |
+
if upload:
|
550 |
+
print("\n• Creating comet Experiment")
|
551 |
+
exp = comet_ml.Experiment(project_name="climategan-apply")
|
552 |
+
exp.log_parameters(vars(args))
|
553 |
+
|
554 |
+
# --------------------------------------------------------------
|
555 |
+
# ----- Change inferred data structure to a list of dicts -----
|
556 |
+
# --------------------------------------------------------------
|
557 |
+
to_write = []
|
558 |
+
events_names = list(all_events[0].keys())
|
559 |
+
for events_data in all_events:
|
560 |
+
n_ims = len(events_data[events_names[0]])
|
561 |
+
for i in range(n_ims):
|
562 |
+
item = {event: events_data[event][i] for event in events_names}
|
563 |
+
to_write.append(item)
|
564 |
+
|
565 |
+
progress_bar_desc = ""
|
566 |
+
if outdir is not None:
|
567 |
+
print("\n• Output directory:\n")
|
568 |
+
print(str(outdir), "\n")
|
569 |
+
if upload:
|
570 |
+
progress_bar_desc = "Writing & Uploading events"
|
571 |
+
else:
|
572 |
+
progress_bar_desc = "Writing events"
|
573 |
+
else:
|
574 |
+
if upload:
|
575 |
+
progress_bar_desc = "Uploading events"
|
576 |
+
|
577 |
+
# ------------------------------------
|
578 |
+
# ----- Save individual images -----
|
579 |
+
# ------------------------------------
|
580 |
+
with Timer(store=stores.get("write", []), ignore=time_inference):
|
581 |
+
|
582 |
+
# for each image
|
583 |
+
for t, event_dict in tqdm(
|
584 |
+
enumerate(to_write),
|
585 |
+
desc=progress_bar_desc,
|
586 |
+
unit="input image",
|
587 |
+
total=len(to_write),
|
588 |
+
):
|
589 |
+
|
590 |
+
idx = t % len(base_data_paths)
|
591 |
+
stem = Path(data_paths[idx]).stem
|
592 |
+
width = new_sizes[idx][1]
|
593 |
+
|
594 |
+
if keep_ratio:
|
595 |
+
ar = "_AR"
|
596 |
+
else:
|
597 |
+
ar = ""
|
598 |
+
|
599 |
+
# for each event type
|
600 |
+
event_bar = tqdm(
|
601 |
+
enumerate(event_dict.items()),
|
602 |
+
leave=False,
|
603 |
+
total=len(events_names),
|
604 |
+
unit="event",
|
605 |
+
)
|
606 |
+
for e, (event, im_data) in event_bar:
|
607 |
+
event_bar.set_description(
|
608 |
+
f" {event.capitalize():<{len(progress_bar_desc) - 2}}"
|
609 |
+
)
|
610 |
+
|
611 |
+
if args.no_cloudy:
|
612 |
+
suffix = ar + "_no_cloudy"
|
613 |
+
else:
|
614 |
+
suffix = ar
|
615 |
+
|
616 |
+
im_path = Path(f"{stem}_{event}_{width}{suffix}.png")
|
617 |
+
|
618 |
+
if outdir is not None:
|
619 |
+
im_path = outdir / im_path
|
620 |
+
io.imsave(im_path, im_data)
|
621 |
+
|
622 |
+
if upload:
|
623 |
+
exp.log_image(im_data, name=im_path.name)
|
624 |
+
if zip_outdir:
|
625 |
+
print("\n• Zipping output directory... ", end="", flush=True)
|
626 |
+
archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
|
627 |
+
archive_path = archive_path.rename(outdir.parent / archive_path.name)
|
628 |
+
print("Done:\n")
|
629 |
+
print(str(archive_path))
|
630 |
+
|
631 |
+
# ---------------------------
|
632 |
+
# ----- Print timings -----
|
633 |
+
# ---------------------------
|
634 |
+
if time_inference:
|
635 |
+
print("\n• Timings\n")
|
636 |
+
print_store(stores)
|
637 |
+
|
638 |
+
# ---------------------------------------------
|
639 |
+
# ----- Save apply_events.py run config -----
|
640 |
+
# ---------------------------------------------
|
641 |
+
if not args.no_conf and outdir is not None:
|
642 |
+
write_apply_config(outdir)
|
climategan/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
import_module(f".{f.stem}", __package__)
|
6 |
+
for f in Path(__file__).parent.glob("*.py")
|
7 |
+
if "__" not in f.stem
|
8 |
+
]
|
9 |
+
del import_module, Path
|
climategan/blocks.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""File for all blocks which are parts of decoders
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import climategan.strings as strings
|
8 |
+
from climategan.norms import SPADE, AdaptiveInstanceNorm2d, LayerNorm, SpectralNorm
|
9 |
+
|
10 |
+
|
11 |
+
class InterpolateNearest2d(nn.Module):
|
12 |
+
"""
|
13 |
+
Custom implementation of nn.Upsample because pytorch/xla
|
14 |
+
does not yet support scale_factor and needs to be provided with
|
15 |
+
the output_size
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, scale_factor=2):
|
19 |
+
"""
|
20 |
+
Create an InterpolateNearest2d module
|
21 |
+
|
22 |
+
Args:
|
23 |
+
scale_factor (int, optional): Output size multiplier. Defaults to 2.
|
24 |
+
"""
|
25 |
+
super().__init__()
|
26 |
+
self.scale_factor = scale_factor
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
"""
|
30 |
+
Interpolate x in "nearest" mode on its last 2 dimensions
|
31 |
+
|
32 |
+
Args:
|
33 |
+
x (torch.Tensor): input to interpolate
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
torch.Tensor: upsampled tensor with shape
|
37 |
+
(...x.shape, x.shape[-2] * scale_factor, x.shape[-1] * scale_factor)
|
38 |
+
"""
|
39 |
+
return F.interpolate(
|
40 |
+
x,
|
41 |
+
size=(x.shape[-2] * self.scale_factor, x.shape[-1] * self.scale_factor),
|
42 |
+
mode="nearest",
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
# -----------------------------------------
|
47 |
+
# ----- Generic Convolutional Block -----
|
48 |
+
# -----------------------------------------
|
49 |
+
class Conv2dBlock(nn.Module):
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
input_dim,
|
53 |
+
output_dim,
|
54 |
+
kernel_size,
|
55 |
+
stride=1,
|
56 |
+
padding=0,
|
57 |
+
dilation=1,
|
58 |
+
norm="none",
|
59 |
+
activation="relu",
|
60 |
+
pad_type="zero",
|
61 |
+
bias=True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.use_bias = bias
|
65 |
+
# initialize padding
|
66 |
+
if pad_type == "reflect":
|
67 |
+
self.pad = nn.ReflectionPad2d(padding)
|
68 |
+
elif pad_type == "replicate":
|
69 |
+
self.pad = nn.ReplicationPad2d(padding)
|
70 |
+
elif pad_type == "zero":
|
71 |
+
self.pad = nn.ZeroPad2d(padding)
|
72 |
+
else:
|
73 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
74 |
+
|
75 |
+
# initialize normalization
|
76 |
+
use_spectral_norm = False
|
77 |
+
if norm.startswith("spectral_"):
|
78 |
+
norm = norm.replace("spectral_", "")
|
79 |
+
use_spectral_norm = True
|
80 |
+
|
81 |
+
norm_dim = output_dim
|
82 |
+
if norm == "batch":
|
83 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
84 |
+
elif norm == "instance":
|
85 |
+
# self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
|
86 |
+
self.norm = nn.InstanceNorm2d(norm_dim)
|
87 |
+
elif norm == "layer":
|
88 |
+
self.norm = LayerNorm(norm_dim)
|
89 |
+
elif norm == "adain":
|
90 |
+
self.norm = AdaptiveInstanceNorm2d(norm_dim)
|
91 |
+
elif norm == "spectral" or norm.startswith("spectral_"):
|
92 |
+
self.norm = None # dealt with later in the code
|
93 |
+
elif norm == "none":
|
94 |
+
self.norm = None
|
95 |
+
else:
|
96 |
+
raise ValueError("Unsupported normalization: {}".format(norm))
|
97 |
+
|
98 |
+
# initialize activation
|
99 |
+
if activation == "relu":
|
100 |
+
self.activation = nn.ReLU(inplace=False)
|
101 |
+
elif activation == "lrelu":
|
102 |
+
self.activation = nn.LeakyReLU(0.2, inplace=False)
|
103 |
+
elif activation == "prelu":
|
104 |
+
self.activation = nn.PReLU()
|
105 |
+
elif activation == "selu":
|
106 |
+
self.activation = nn.SELU(inplace=False)
|
107 |
+
elif activation == "tanh":
|
108 |
+
self.activation = nn.Tanh()
|
109 |
+
elif activation == "sigmoid":
|
110 |
+
self.activation = nn.Sigmoid()
|
111 |
+
elif activation == "none":
|
112 |
+
self.activation = None
|
113 |
+
else:
|
114 |
+
raise ValueError("Unsupported activation: {}".format(activation))
|
115 |
+
|
116 |
+
# initialize convolution
|
117 |
+
if norm == "spectral" or use_spectral_norm:
|
118 |
+
self.conv = SpectralNorm(
|
119 |
+
nn.Conv2d(
|
120 |
+
input_dim,
|
121 |
+
output_dim,
|
122 |
+
kernel_size,
|
123 |
+
stride,
|
124 |
+
dilation=dilation,
|
125 |
+
bias=self.use_bias,
|
126 |
+
)
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
self.conv = nn.Conv2d(
|
130 |
+
input_dim,
|
131 |
+
output_dim,
|
132 |
+
kernel_size,
|
133 |
+
stride,
|
134 |
+
dilation=dilation,
|
135 |
+
bias=self.use_bias if norm != "batch" else False,
|
136 |
+
)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
x = self.conv(self.pad(x))
|
140 |
+
if self.norm is not None:
|
141 |
+
x = self.norm(x)
|
142 |
+
if self.activation is not None:
|
143 |
+
x = self.activation(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
def __str__(self):
|
147 |
+
return strings.conv2dblock(self)
|
148 |
+
|
149 |
+
|
150 |
+
# -----------------------------
|
151 |
+
# ----- Residual Blocks -----
|
152 |
+
# -----------------------------
|
153 |
+
class ResBlocks(nn.Module):
|
154 |
+
"""
|
155 |
+
From https://github.com/NVlabs/MUNIT/blob/master/networks.py
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, num_blocks, dim, norm="in", activation="relu", pad_type="zero"):
|
159 |
+
super().__init__()
|
160 |
+
self.model = nn.Sequential(
|
161 |
+
*[
|
162 |
+
ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
|
163 |
+
for _ in range(num_blocks)
|
164 |
+
]
|
165 |
+
)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
return self.model(x)
|
169 |
+
|
170 |
+
def __str__(self):
|
171 |
+
return strings.resblocks(self)
|
172 |
+
|
173 |
+
|
174 |
+
class ResBlock(nn.Module):
|
175 |
+
def __init__(self, dim, norm="in", activation="relu", pad_type="zero"):
|
176 |
+
super().__init__()
|
177 |
+
self.dim = dim
|
178 |
+
self.norm = norm
|
179 |
+
self.activation = activation
|
180 |
+
model = []
|
181 |
+
model += [
|
182 |
+
Conv2dBlock(
|
183 |
+
dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
|
184 |
+
)
|
185 |
+
]
|
186 |
+
model += [
|
187 |
+
Conv2dBlock(
|
188 |
+
dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
|
189 |
+
)
|
190 |
+
]
|
191 |
+
self.model = nn.Sequential(*model)
|
192 |
+
|
193 |
+
def forward(self, x):
|
194 |
+
residual = x
|
195 |
+
out = self.model(x)
|
196 |
+
out += residual
|
197 |
+
return out
|
198 |
+
|
199 |
+
def __str__(self):
|
200 |
+
return strings.resblock(self)
|
201 |
+
|
202 |
+
|
203 |
+
# --------------------------
|
204 |
+
# ----- Base Decoder -----
|
205 |
+
# --------------------------
|
206 |
+
class BaseDecoder(nn.Module):
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
n_upsample=4,
|
210 |
+
n_res=4,
|
211 |
+
input_dim=2048,
|
212 |
+
proj_dim=64,
|
213 |
+
output_dim=3,
|
214 |
+
norm="batch",
|
215 |
+
activ="relu",
|
216 |
+
pad_type="zero",
|
217 |
+
output_activ="tanh",
|
218 |
+
low_level_feats_dim=-1,
|
219 |
+
use_dada=False,
|
220 |
+
):
|
221 |
+
super().__init__()
|
222 |
+
|
223 |
+
self.low_level_feats_dim = low_level_feats_dim
|
224 |
+
self.use_dada = use_dada
|
225 |
+
|
226 |
+
self.model = []
|
227 |
+
if proj_dim != -1:
|
228 |
+
self.proj_conv = Conv2dBlock(
|
229 |
+
input_dim, proj_dim, 1, 1, 0, norm=norm, activation=activ
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
self.proj_conv = None
|
233 |
+
proj_dim = input_dim
|
234 |
+
|
235 |
+
if low_level_feats_dim > 0:
|
236 |
+
self.low_level_conv = Conv2dBlock(
|
237 |
+
input_dim=low_level_feats_dim,
|
238 |
+
output_dim=proj_dim,
|
239 |
+
kernel_size=3,
|
240 |
+
stride=1,
|
241 |
+
padding=1,
|
242 |
+
pad_type=pad_type,
|
243 |
+
norm=norm,
|
244 |
+
activation=activ,
|
245 |
+
)
|
246 |
+
self.merge_feats_conv = Conv2dBlock(
|
247 |
+
input_dim=2 * proj_dim,
|
248 |
+
output_dim=proj_dim,
|
249 |
+
kernel_size=1,
|
250 |
+
stride=1,
|
251 |
+
padding=0,
|
252 |
+
pad_type=pad_type,
|
253 |
+
norm=norm,
|
254 |
+
activation=activ,
|
255 |
+
)
|
256 |
+
else:
|
257 |
+
self.low_level_conv = None
|
258 |
+
|
259 |
+
self.model += [ResBlocks(n_res, proj_dim, norm, activ, pad_type=pad_type)]
|
260 |
+
dim = proj_dim
|
261 |
+
# upsampling blocks
|
262 |
+
for i in range(n_upsample):
|
263 |
+
self.model += [
|
264 |
+
InterpolateNearest2d(scale_factor=2),
|
265 |
+
Conv2dBlock(
|
266 |
+
input_dim=dim,
|
267 |
+
output_dim=dim // 2,
|
268 |
+
kernel_size=3,
|
269 |
+
stride=1,
|
270 |
+
padding=1,
|
271 |
+
pad_type=pad_type,
|
272 |
+
norm=norm,
|
273 |
+
activation=activ,
|
274 |
+
),
|
275 |
+
]
|
276 |
+
dim //= 2
|
277 |
+
# use reflection padding in the last conv layer
|
278 |
+
self.model += [
|
279 |
+
Conv2dBlock(
|
280 |
+
input_dim=dim,
|
281 |
+
output_dim=output_dim,
|
282 |
+
kernel_size=3,
|
283 |
+
stride=1,
|
284 |
+
padding=1,
|
285 |
+
pad_type=pad_type,
|
286 |
+
norm="none",
|
287 |
+
activation=output_activ,
|
288 |
+
)
|
289 |
+
]
|
290 |
+
self.model = nn.Sequential(*self.model)
|
291 |
+
|
292 |
+
def forward(self, z, cond=None, z_depth=None):
|
293 |
+
low_level_feat = None
|
294 |
+
if isinstance(z, (list, tuple)):
|
295 |
+
if self.low_level_conv is None:
|
296 |
+
z = z[0]
|
297 |
+
else:
|
298 |
+
z, low_level_feat = z
|
299 |
+
low_level_feat = self.low_level_conv(low_level_feat)
|
300 |
+
low_level_feat = F.interpolate(
|
301 |
+
low_level_feat, size=z.shape[-2:], mode="bilinear"
|
302 |
+
)
|
303 |
+
|
304 |
+
if z_depth is not None and self.use_dada:
|
305 |
+
z = z * z_depth
|
306 |
+
|
307 |
+
if self.proj_conv is not None:
|
308 |
+
z = self.proj_conv(z)
|
309 |
+
|
310 |
+
if low_level_feat is not None:
|
311 |
+
z = self.merge_feats_conv(torch.cat([low_level_feat, z], dim=1))
|
312 |
+
|
313 |
+
return self.model(z)
|
314 |
+
|
315 |
+
def __str__(self):
|
316 |
+
return strings.basedecoder(self)
|
317 |
+
|
318 |
+
|
319 |
+
# --------------------------
|
320 |
+
# ----- SPADE Blocks -----
|
321 |
+
# --------------------------
|
322 |
+
# https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
|
323 |
+
# /models/networks/generator.py
|
324 |
+
# 0ff661e on 13 Apr 2019
|
325 |
+
class SPADEResnetBlock(nn.Module):
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
fin,
|
329 |
+
fout,
|
330 |
+
cond_nc,
|
331 |
+
spade_use_spectral_norm,
|
332 |
+
spade_param_free_norm,
|
333 |
+
spade_kernel_size,
|
334 |
+
last_activation=None,
|
335 |
+
):
|
336 |
+
super().__init__()
|
337 |
+
# Attributes
|
338 |
+
|
339 |
+
self.fin = fin
|
340 |
+
self.fout = fout
|
341 |
+
self.use_spectral_norm = spade_use_spectral_norm
|
342 |
+
self.param_free_norm = spade_param_free_norm
|
343 |
+
self.kernel_size = spade_kernel_size
|
344 |
+
|
345 |
+
self.learned_shortcut = fin != fout
|
346 |
+
self.last_activation = last_activation
|
347 |
+
fmiddle = min(fin, fout)
|
348 |
+
|
349 |
+
# create conv layers
|
350 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
|
351 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
|
352 |
+
if self.learned_shortcut:
|
353 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
354 |
+
|
355 |
+
# apply spectral norm if specified
|
356 |
+
if spade_use_spectral_norm:
|
357 |
+
self.conv_0 = SpectralNorm(self.conv_0)
|
358 |
+
self.conv_1 = SpectralNorm(self.conv_1)
|
359 |
+
if self.learned_shortcut:
|
360 |
+
self.conv_s = SpectralNorm(self.conv_s)
|
361 |
+
|
362 |
+
self.norm_0 = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
|
363 |
+
self.norm_1 = SPADE(spade_param_free_norm, spade_kernel_size, fmiddle, cond_nc)
|
364 |
+
if self.learned_shortcut:
|
365 |
+
self.norm_s = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
|
366 |
+
|
367 |
+
# note the resnet block with SPADE also takes in |seg|,
|
368 |
+
# the semantic segmentation map as input
|
369 |
+
def forward(self, x, seg):
|
370 |
+
x_s = self.shortcut(x, seg)
|
371 |
+
|
372 |
+
dx = self.conv_0(self.activation(self.norm_0(x, seg)))
|
373 |
+
dx = self.conv_1(self.activation(self.norm_1(dx, seg)))
|
374 |
+
|
375 |
+
out = x_s + dx
|
376 |
+
if self.last_activation == "lrelu":
|
377 |
+
return self.activation(out)
|
378 |
+
elif self.last_activation is None:
|
379 |
+
return out
|
380 |
+
else:
|
381 |
+
raise NotImplementedError(
|
382 |
+
"The type of activation is not supported: {}".format(
|
383 |
+
self.last_activation
|
384 |
+
)
|
385 |
+
)
|
386 |
+
|
387 |
+
def shortcut(self, x, seg):
|
388 |
+
if self.learned_shortcut:
|
389 |
+
x_s = self.conv_s(self.norm_s(x, seg))
|
390 |
+
else:
|
391 |
+
x_s = x
|
392 |
+
return x_s
|
393 |
+
|
394 |
+
def activation(self, x):
|
395 |
+
return F.leaky_relu(x, 2e-1)
|
396 |
+
|
397 |
+
def __str__(self):
|
398 |
+
return strings.spaderesblock(self)
|
climategan/bn_fusion.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from copy import deepcopy
|
3 |
+
|
4 |
+
|
5 |
+
class FlattableModel(object):
|
6 |
+
def __init__(self, model):
|
7 |
+
self.model = deepcopy(model)
|
8 |
+
self._original_model = model
|
9 |
+
self._flat_model = None
|
10 |
+
self._attr_names = self.get_attributes_name()
|
11 |
+
|
12 |
+
def flatten_model(self):
|
13 |
+
if self._flat_model is None:
|
14 |
+
self._flat_model = self._flatten_model(self.model)
|
15 |
+
return self._flat_model
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def _selection_method(module):
|
19 |
+
return not (
|
20 |
+
isinstance(module, torch.nn.Sequential)
|
21 |
+
or isinstance(module, torch.nn.ModuleList)
|
22 |
+
) and not hasattr(module, "_restricted")
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def _flatten_model(module):
|
26 |
+
modules = []
|
27 |
+
child = False
|
28 |
+
for (name, c) in module.named_children():
|
29 |
+
child = True
|
30 |
+
flattened_c = FlattableModel._flatten_model(c)
|
31 |
+
modules += flattened_c
|
32 |
+
if not child and FlattableModel._selection_method(module):
|
33 |
+
modules = [module]
|
34 |
+
return modules
|
35 |
+
|
36 |
+
def get_layer_io(self, layer, nb_samples, data_loader):
|
37 |
+
ios = []
|
38 |
+
hook = layer.register_forward_hook(
|
39 |
+
lambda m, i, o: ios.append((i[0].data.cpu(), o.data.cpu()))
|
40 |
+
)
|
41 |
+
|
42 |
+
nbatch = 1
|
43 |
+
for batch_idx, (xs, ys) in enumerate(data_loader):
|
44 |
+
# -1 takes all of them
|
45 |
+
if nb_samples != -1 and nbatch > nb_samples:
|
46 |
+
break
|
47 |
+
_ = self.model(xs.cuda())
|
48 |
+
nbatch += 1
|
49 |
+
|
50 |
+
hook.remove()
|
51 |
+
return ios
|
52 |
+
|
53 |
+
def get_attributes_name(self):
|
54 |
+
def _real_get_attributes_name(module):
|
55 |
+
modules = []
|
56 |
+
child = False
|
57 |
+
for (name, c) in module.named_children():
|
58 |
+
child = True
|
59 |
+
flattened_c = _real_get_attributes_name(c)
|
60 |
+
modules += map(lambda e: [name] + e, flattened_c)
|
61 |
+
if not child and FlattableModel._selection_method(module):
|
62 |
+
modules = [[]]
|
63 |
+
return modules
|
64 |
+
|
65 |
+
return _real_get_attributes_name(self.model)
|
66 |
+
|
67 |
+
def update_model(self, flat_model):
|
68 |
+
"""
|
69 |
+
Take a list representing the flatten model and rebuild its internals.
|
70 |
+
:type flat_model: List[nn.Module]
|
71 |
+
"""
|
72 |
+
|
73 |
+
def _apply_changes_on_layer(block, idxs, layer):
|
74 |
+
assert len(idxs) > 0
|
75 |
+
if len(idxs) == 1:
|
76 |
+
setattr(block, idxs[0], layer)
|
77 |
+
else:
|
78 |
+
_apply_changes_on_layer(getattr(block, idxs[0]), idxs[1:], layer)
|
79 |
+
|
80 |
+
def _apply_changes_model(model_list):
|
81 |
+
for i in range(len(model_list)):
|
82 |
+
_apply_changes_on_layer(self.model, self._attr_names[i], model_list[i])
|
83 |
+
|
84 |
+
_apply_changes_model(flat_model)
|
85 |
+
self._attr_names = self.get_attributes_name()
|
86 |
+
self._flat_model = None
|
87 |
+
|
88 |
+
def cuda(self):
|
89 |
+
self.model = self.model.cuda()
|
90 |
+
return self
|
91 |
+
|
92 |
+
def cpu(self):
|
93 |
+
self.model = self.model.cpu()
|
94 |
+
return self
|
95 |
+
|
96 |
+
|
97 |
+
def bn_fuse(model):
|
98 |
+
model = model.cpu()
|
99 |
+
flattable = FlattableModel(model)
|
100 |
+
fmodel = flattable.flatten_model()
|
101 |
+
|
102 |
+
for index, item in enumerate(fmodel):
|
103 |
+
if (
|
104 |
+
isinstance(item, torch.nn.Conv2d)
|
105 |
+
and index + 1 < len(fmodel)
|
106 |
+
and isinstance(fmodel[index + 1], torch.nn.BatchNorm2d)
|
107 |
+
):
|
108 |
+
alpha, beta = _calculate_alpha_beta(fmodel[index + 1])
|
109 |
+
if item.weight.shape[0] != alpha.shape[0]:
|
110 |
+
# this case happens if there was actually something else
|
111 |
+
# between the conv and the
|
112 |
+
# bn layer which is not picked up in flat model logic. (see densenet)
|
113 |
+
continue
|
114 |
+
item.weight.data = item.weight.data * alpha.view(-1, 1, 1, 1)
|
115 |
+
item.bias = torch.nn.Parameter(beta)
|
116 |
+
fmodel[index + 1] = _IdentityLayer()
|
117 |
+
flattable.update_model(fmodel)
|
118 |
+
return flattable.model
|
119 |
+
|
120 |
+
|
121 |
+
def _calculate_alpha_beta(batchnorm_layer):
|
122 |
+
alpha = batchnorm_layer.weight.data / (
|
123 |
+
torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps)
|
124 |
+
)
|
125 |
+
beta = (
|
126 |
+
-(batchnorm_layer.weight.data * batchnorm_layer.running_mean)
|
127 |
+
/ (torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps))
|
128 |
+
+ batchnorm_layer.bias.data
|
129 |
+
)
|
130 |
+
alpha = alpha.cpu()
|
131 |
+
beta = beta.cpu()
|
132 |
+
return alpha, beta
|
133 |
+
|
134 |
+
|
135 |
+
class _IdentityLayer(torch.nn.Module):
|
136 |
+
def forward(self, input):
|
137 |
+
return input
|
climategan/data.py
ADDED
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Data-loading functions in order to create a Dataset and DataLoaders.
|
2 |
+
Transforms for loaders are in transforms.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import yaml
|
12 |
+
from imageio import imread
|
13 |
+
from PIL import Image
|
14 |
+
from torch.utils.data import DataLoader, Dataset
|
15 |
+
from torchvision import transforms
|
16 |
+
|
17 |
+
from climategan.transforms import get_transforms
|
18 |
+
from climategan.tutils import get_normalized_depth_t
|
19 |
+
from climategan.utils import env_to_path, is_image_file
|
20 |
+
|
21 |
+
classes_dict = {
|
22 |
+
"s": { # unity
|
23 |
+
0: [0, 0, 255, 255], # Water
|
24 |
+
1: [55, 55, 55, 255], # Ground
|
25 |
+
2: [0, 255, 255, 255], # Building
|
26 |
+
3: [255, 212, 0, 255], # Traffic items
|
27 |
+
4: [0, 255, 0, 255], # Vegetation
|
28 |
+
5: [255, 97, 0, 255], # Terrain
|
29 |
+
6: [255, 0, 0, 255], # Car
|
30 |
+
7: [60, 180, 60, 255], # Trees
|
31 |
+
8: [255, 0, 255, 255], # Person
|
32 |
+
9: [0, 0, 0, 255], # Sky
|
33 |
+
10: [255, 255, 255, 255], # Default
|
34 |
+
},
|
35 |
+
"r": { # deeplab v2
|
36 |
+
0: [0, 0, 255, 255], # Water
|
37 |
+
1: [55, 55, 55, 255], # Ground
|
38 |
+
2: [0, 255, 255, 255], # Building
|
39 |
+
3: [255, 212, 0, 255], # Traffic items
|
40 |
+
4: [0, 255, 0, 255], # Vegetation
|
41 |
+
5: [255, 97, 0, 255], # Terrain
|
42 |
+
6: [255, 0, 0, 255], # Car
|
43 |
+
7: [60, 180, 60, 255], # Trees
|
44 |
+
8: [220, 20, 60, 255], # Person
|
45 |
+
9: [8, 19, 49, 255], # Sky
|
46 |
+
10: [0, 80, 100, 255], # Default
|
47 |
+
},
|
48 |
+
"kitti": {
|
49 |
+
0: [210, 0, 200], # Terrain
|
50 |
+
1: [90, 200, 255], # Sky
|
51 |
+
2: [0, 199, 0], # Tree
|
52 |
+
3: [90, 240, 0], # Vegetation
|
53 |
+
4: [140, 140, 140], # Building
|
54 |
+
5: [100, 60, 100], # Road
|
55 |
+
6: [250, 100, 255], # GuardRail
|
56 |
+
7: [255, 255, 0], # TrafficSign
|
57 |
+
8: [200, 200, 0], # TrafficLight
|
58 |
+
9: [255, 130, 0], # Pole
|
59 |
+
10: [80, 80, 80], # Misc
|
60 |
+
11: [160, 60, 60], # Truck
|
61 |
+
12: [255, 127, 80], # Car
|
62 |
+
13: [0, 139, 139], # Van
|
63 |
+
14: [0, 0, 0], # Undefined
|
64 |
+
},
|
65 |
+
"flood": {
|
66 |
+
0: [255, 0, 0], # Cannot flood
|
67 |
+
1: [0, 0, 255], # Must flood
|
68 |
+
2: [0, 0, 0], # May flood
|
69 |
+
},
|
70 |
+
}
|
71 |
+
|
72 |
+
kitti_mapping = {
|
73 |
+
0: 5, # Terrain -> Terrain
|
74 |
+
1: 9, # Sky -> Sky
|
75 |
+
2: 7, # Tree -> Trees
|
76 |
+
3: 4, # Vegetation -> Vegetation
|
77 |
+
4: 2, # Building -> Building
|
78 |
+
5: 1, # Road -> Ground
|
79 |
+
6: 3, # GuardRail -> Traffic items
|
80 |
+
7: 3, # TrafficSign -> Traffic items
|
81 |
+
8: 3, # TrafficLight -> Traffic items
|
82 |
+
9: 3, # Pole -> Traffic items
|
83 |
+
10: 10, # Misc -> default
|
84 |
+
11: 6, # Truck -> Car
|
85 |
+
12: 6, # Car -> Car
|
86 |
+
13: 6, # Van -> Car
|
87 |
+
14: 10, # Undefined -> Default
|
88 |
+
}
|
89 |
+
|
90 |
+
|
91 |
+
def encode_exact_segmap(seg, classes_dict, default_value=14):
|
92 |
+
"""
|
93 |
+
When the mapping (rgb -> label) is known to be exact (no approximative rgb values)
|
94 |
+
maps rgb image to segmap labels
|
95 |
+
|
96 |
+
Args:
|
97 |
+
seg (np.ndarray): H x W x 3 RGB image
|
98 |
+
classes_dict (dict): Mapping {class: rgb value}
|
99 |
+
default_value (int, optional): Value for unknown label. Defaults to 14.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
np.ndarray: Segmap as labels, not RGB
|
103 |
+
"""
|
104 |
+
out = np.ones((seg.shape[0], seg.shape[1])) * default_value
|
105 |
+
for cindex, cvalue in classes_dict.items():
|
106 |
+
out[np.where((seg == cvalue).all(-1))] = cindex
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
def merge_labels(labels, mapping, default_value=14):
|
111 |
+
"""
|
112 |
+
Maps labels from a source domain to labels of a target domain,
|
113 |
+
typically kitti -> climategan
|
114 |
+
|
115 |
+
Args:
|
116 |
+
labels (np.ndarray): input segmap labels
|
117 |
+
mapping (dict): source_label -> target_label
|
118 |
+
default_value (int, optional): Unknown label. Defaults to 14.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
np.ndarray: Adapted labels
|
122 |
+
"""
|
123 |
+
out = np.ones_like(labels) * default_value
|
124 |
+
for source, target in mapping.items():
|
125 |
+
out[labels == source] = target
|
126 |
+
return out
|
127 |
+
|
128 |
+
|
129 |
+
def process_kitti_seg(path, kitti_classes, merge_map, default=14):
|
130 |
+
"""
|
131 |
+
Processes a path to produce a 1 x 1 x H x W torch segmap
|
132 |
+
|
133 |
+
%timeit process_kitti_seg(path, classes_dict, mapping, default=14)
|
134 |
+
326 ms ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
|
135 |
+
|
136 |
+
Args:
|
137 |
+
path (str | pathlib.Path): Segmap RBG path
|
138 |
+
kitti_classes (dict): Kitti map label -> rgb
|
139 |
+
merge_map (dict): map kitti_label -> climategan_label
|
140 |
+
default (int, optional): Unknown kitti label. Defaults to 14.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
torch.Tensor: 1 x 1 x H x W torch segmap
|
144 |
+
"""
|
145 |
+
seg = imread(path)
|
146 |
+
labels = encode_exact_segmap(seg, kitti_classes, default_value=default)
|
147 |
+
merged = merge_labels(labels, merge_map, default_value=default)
|
148 |
+
return torch.tensor(merged).unsqueeze(0).unsqueeze(0)
|
149 |
+
|
150 |
+
|
151 |
+
def decode_segmap_merged_labels(tensor, domain, is_target, nc=11):
|
152 |
+
"""Creates a label colormap for classes used in Unity segmentation benchmark.
|
153 |
+
Arguments:
|
154 |
+
tensor -- segmented image of size (1) x (nc) x (H) x (W)
|
155 |
+
if prediction, or size (1) x (1) x (H) x (W) if target
|
156 |
+
Returns:
|
157 |
+
RGB tensor of size (1) x (3) x (H) x (W)
|
158 |
+
#"""
|
159 |
+
|
160 |
+
if is_target: # Target is size 1 x 1 x H x W
|
161 |
+
idx = tensor.squeeze(0).squeeze(0)
|
162 |
+
else: # Prediction is size 1 x nc x H x W
|
163 |
+
idx = torch.argmax(tensor.squeeze(0), dim=0)
|
164 |
+
|
165 |
+
indexer = torch.tensor(list(classes_dict[domain].values()))[:, :3]
|
166 |
+
return indexer[idx.long()].permute(2, 0, 1).to(torch.float32).unsqueeze(0)
|
167 |
+
|
168 |
+
|
169 |
+
def decode_segmap_cityscapes_labels(image, nc=19):
|
170 |
+
"""Creates a label colormap used in CITYSCAPES segmentation benchmark.
|
171 |
+
Arguments:
|
172 |
+
image {array} -- segmented image
|
173 |
+
(array of image size containing class at each pixel)
|
174 |
+
Returns:
|
175 |
+
array of size 3*nc -- A colormap for visualizing segmentation results.
|
176 |
+
"""
|
177 |
+
colormap = np.zeros((19, 3), dtype=np.uint8)
|
178 |
+
colormap[0] = [128, 64, 128]
|
179 |
+
colormap[1] = [244, 35, 232]
|
180 |
+
colormap[2] = [70, 70, 70]
|
181 |
+
colormap[3] = [102, 102, 156]
|
182 |
+
colormap[4] = [190, 153, 153]
|
183 |
+
colormap[5] = [153, 153, 153]
|
184 |
+
colormap[6] = [250, 170, 30]
|
185 |
+
colormap[7] = [220, 220, 0]
|
186 |
+
colormap[8] = [107, 142, 35]
|
187 |
+
colormap[9] = [152, 251, 152]
|
188 |
+
colormap[10] = [70, 130, 180]
|
189 |
+
colormap[11] = [220, 20, 60]
|
190 |
+
colormap[12] = [255, 0, 0]
|
191 |
+
colormap[13] = [0, 0, 142]
|
192 |
+
colormap[14] = [0, 0, 70]
|
193 |
+
colormap[15] = [0, 60, 100]
|
194 |
+
colormap[16] = [0, 80, 100]
|
195 |
+
colormap[17] = [0, 0, 230]
|
196 |
+
colormap[18] = [119, 11, 32]
|
197 |
+
|
198 |
+
r = np.zeros_like(image).astype(np.uint8)
|
199 |
+
g = np.zeros_like(image).astype(np.uint8)
|
200 |
+
b = np.zeros_like(image).astype(np.uint8)
|
201 |
+
|
202 |
+
for col in range(nc):
|
203 |
+
idx = image == col
|
204 |
+
r[idx] = colormap[col, 0]
|
205 |
+
g[idx] = colormap[col, 1]
|
206 |
+
b[idx] = colormap[col, 2]
|
207 |
+
|
208 |
+
rgb = np.stack([r, g, b], axis=2)
|
209 |
+
return rgb
|
210 |
+
|
211 |
+
|
212 |
+
def find_closest_class(pixel, dict_classes):
|
213 |
+
"""Takes a pixel as input and finds the closest known pixel value corresponding
|
214 |
+
to a class in dict_classes
|
215 |
+
|
216 |
+
Arguments:
|
217 |
+
pixel -- tuple pixel (R,G,B,A)
|
218 |
+
Returns:
|
219 |
+
tuple pixel (R,G,B,A) corresponding to a key (a class) in dict_classes
|
220 |
+
"""
|
221 |
+
min_dist = float("inf")
|
222 |
+
closest_pixel = None
|
223 |
+
for pixel_value in dict_classes.keys():
|
224 |
+
dist = np.sqrt(np.sum(np.square(np.subtract(pixel, pixel_value))))
|
225 |
+
if dist < min_dist:
|
226 |
+
min_dist = dist
|
227 |
+
closest_pixel = pixel_value
|
228 |
+
return closest_pixel
|
229 |
+
|
230 |
+
|
231 |
+
def encode_segmap(arr, domain):
|
232 |
+
"""Change a segmentation RGBA array to a segmentation array
|
233 |
+
with each pixel being the index of the class
|
234 |
+
Arguments:
|
235 |
+
numpy array -- segmented image of size (H) x (W) x (4 RGBA values)
|
236 |
+
Returns:
|
237 |
+
numpy array of size (1) x (H) x (W) with each pixel being the index of the class
|
238 |
+
"""
|
239 |
+
new_arr = np.zeros((1, arr.shape[0], arr.shape[1]))
|
240 |
+
dict_classes = {
|
241 |
+
tuple(rgba_value): class_id
|
242 |
+
for (class_id, rgba_value) in classes_dict[domain].items()
|
243 |
+
}
|
244 |
+
for i in range(arr.shape[0]):
|
245 |
+
for j in range(arr.shape[1]):
|
246 |
+
pixel_rgba = tuple(arr[i, j, :])
|
247 |
+
if pixel_rgba in dict_classes.keys():
|
248 |
+
new_arr[0, i, j] = dict_classes[pixel_rgba]
|
249 |
+
else:
|
250 |
+
pixel_rgba_closest = find_closest_class(pixel_rgba, dict_classes)
|
251 |
+
new_arr[0, i, j] = dict_classes[pixel_rgba_closest]
|
252 |
+
return new_arr
|
253 |
+
|
254 |
+
|
255 |
+
def encode_mask_label(arr, domain):
|
256 |
+
"""Change a segmentation RGBA array to a segmentation array
|
257 |
+
with each pixel being the index of the class
|
258 |
+
Arguments:
|
259 |
+
numpy array -- segmented image of size (H) x (W) x (3 RGB values)
|
260 |
+
Returns:
|
261 |
+
numpy array of size (1) x (H) x (W) with each pixel being the index of the class
|
262 |
+
"""
|
263 |
+
diff = np.zeros((len(classes_dict[domain].keys()), arr.shape[0], arr.shape[1]))
|
264 |
+
for cindex, cvalue in classes_dict[domain].items():
|
265 |
+
diff[cindex, :, :] = np.sqrt(
|
266 |
+
np.sum(
|
267 |
+
np.square(arr - np.tile(cvalue, (arr.shape[0], arr.shape[1], 1))),
|
268 |
+
axis=2,
|
269 |
+
)
|
270 |
+
)
|
271 |
+
return np.expand_dims(np.argmin(diff, axis=0), axis=0)
|
272 |
+
|
273 |
+
|
274 |
+
def transform_segmap_image_to_tensor(path, domain):
|
275 |
+
"""
|
276 |
+
Transforms a segmentation image to a tensor of size (1) x (1) x (H) x (W)
|
277 |
+
with each pixel being the index of the class
|
278 |
+
"""
|
279 |
+
arr = np.array(Image.open(path).convert("RGBA"))
|
280 |
+
arr = encode_segmap(arr, domain)
|
281 |
+
arr = torch.from_numpy(arr).float()
|
282 |
+
arr = arr.unsqueeze(0)
|
283 |
+
return arr
|
284 |
+
|
285 |
+
|
286 |
+
def save_segmap_tensors(path_to_json, path_to_dir, domain):
|
287 |
+
"""
|
288 |
+
Loads the segmentation images mentionned in a json file, transforms them to
|
289 |
+
tensors and save the tensors in the wanted directory
|
290 |
+
|
291 |
+
Args:
|
292 |
+
path_to_json: complete path to the json file where to find the original data
|
293 |
+
path_to_dir: path to the directory where to save the tensors as tensor_name.pt
|
294 |
+
domain: domain of the images ("r" or "s")
|
295 |
+
|
296 |
+
e.g:
|
297 |
+
save_tensors(
|
298 |
+
"/network/tmp1/ccai/data/climategan/seg/train_s.json",
|
299 |
+
"/network/tmp1/ccai/data/munit_dataset/simdata/Unity11K_res640/Seg_tensors/",
|
300 |
+
"s",
|
301 |
+
)
|
302 |
+
"""
|
303 |
+
ims_list = None
|
304 |
+
if path_to_json:
|
305 |
+
path_to_json = Path(path_to_json).resolve()
|
306 |
+
with open(path_to_json, "r") as f:
|
307 |
+
ims_list = yaml.safe_load(f)
|
308 |
+
|
309 |
+
assert ims_list is not None
|
310 |
+
|
311 |
+
for im_dict in ims_list:
|
312 |
+
for task_name, path in im_dict.items():
|
313 |
+
if task_name == "s":
|
314 |
+
file_name = os.path.splitext(path)[0] # remove extension
|
315 |
+
file_name = file_name.rsplit("/", 1)[-1] # keep only the file_name
|
316 |
+
tensor = transform_segmap_image_to_tensor(path, domain)
|
317 |
+
torch.save(tensor, path_to_dir + file_name + ".pt")
|
318 |
+
|
319 |
+
|
320 |
+
def pil_image_loader(path, task):
|
321 |
+
if Path(path).suffix == ".npy":
|
322 |
+
arr = np.load(path).astype(np.uint8)
|
323 |
+
elif is_image_file(path):
|
324 |
+
# arr = imread(path).astype(np.uint8)
|
325 |
+
arr = np.array(Image.open(path).convert("RGB"))
|
326 |
+
else:
|
327 |
+
raise ValueError("Unknown data type {}".format(path))
|
328 |
+
|
329 |
+
# Convert from RGBA to RGB for images
|
330 |
+
if len(arr.shape) == 3 and arr.shape[-1] == 4:
|
331 |
+
arr = arr[:, :, 0:3]
|
332 |
+
|
333 |
+
if task == "m":
|
334 |
+
arr[arr != 0] = 1
|
335 |
+
# Make sure mask is single-channel
|
336 |
+
if len(arr.shape) >= 3:
|
337 |
+
arr = arr[:, :, 0]
|
338 |
+
|
339 |
+
# assert len(arr.shape) == 3, (path, task, arr.shape)
|
340 |
+
|
341 |
+
return Image.fromarray(arr)
|
342 |
+
|
343 |
+
|
344 |
+
def tensor_loader(path, task, domain, opts):
|
345 |
+
"""load data as tensors
|
346 |
+
Args:
|
347 |
+
path (str): path to data
|
348 |
+
task (str)
|
349 |
+
domain (str)
|
350 |
+
Returns:
|
351 |
+
[Tensor]: 1 x C x H x W
|
352 |
+
"""
|
353 |
+
if task == "s":
|
354 |
+
if domain == "kitti":
|
355 |
+
return process_kitti_seg(
|
356 |
+
path, classes_dict["kitti"], kitti_mapping, default=14
|
357 |
+
)
|
358 |
+
return torch.load(path)
|
359 |
+
elif task == "d":
|
360 |
+
if Path(path).suffix == ".npy":
|
361 |
+
arr = np.load(path)
|
362 |
+
else:
|
363 |
+
arr = imread(path) # .astype(np.uint8) /!\ kitti is np.uint16
|
364 |
+
tensor = torch.from_numpy(arr.astype(np.float32))
|
365 |
+
tensor = get_normalized_depth_t(
|
366 |
+
tensor,
|
367 |
+
domain,
|
368 |
+
normalize="d" in opts.train.pseudo.tasks,
|
369 |
+
log=opts.gen.d.classify.enable,
|
370 |
+
)
|
371 |
+
tensor = tensor.unsqueeze(0)
|
372 |
+
return tensor
|
373 |
+
|
374 |
+
elif Path(path).suffix == ".npy":
|
375 |
+
arr = np.load(path).astype(np.float32)
|
376 |
+
elif is_image_file(path):
|
377 |
+
arr = imread(path).astype(np.float32)
|
378 |
+
else:
|
379 |
+
raise ValueError("Unknown data type {}".format(path))
|
380 |
+
|
381 |
+
# Convert from RGBA to RGB for images
|
382 |
+
if len(arr.shape) == 3 and arr.shape[-1] == 4:
|
383 |
+
arr = arr[:, :, 0:3]
|
384 |
+
|
385 |
+
if task == "x":
|
386 |
+
arr -= arr.min()
|
387 |
+
arr /= arr.max()
|
388 |
+
arr = np.moveaxis(arr, 2, 0)
|
389 |
+
elif task == "s":
|
390 |
+
arr = np.moveaxis(arr, 2, 0)
|
391 |
+
elif task == "m":
|
392 |
+
if arr.max() > 127:
|
393 |
+
arr = (arr > 127).astype(arr.dtype)
|
394 |
+
# Make sure mask is single-channel
|
395 |
+
if len(arr.shape) >= 3:
|
396 |
+
arr = arr[:, :, 0]
|
397 |
+
arr = np.expand_dims(arr, 0)
|
398 |
+
|
399 |
+
return torch.from_numpy(arr).unsqueeze(0)
|
400 |
+
|
401 |
+
|
402 |
+
class OmniListDataset(Dataset):
|
403 |
+
def __init__(self, mode, domain, opts, transform=None):
|
404 |
+
|
405 |
+
self.opts = opts
|
406 |
+
self.domain = domain
|
407 |
+
self.mode = mode
|
408 |
+
self.tasks = set(opts.tasks)
|
409 |
+
self.tasks.add("x")
|
410 |
+
if "p" in self.tasks:
|
411 |
+
self.tasks.add("m")
|
412 |
+
|
413 |
+
file_list_path = Path(opts.data.files[mode][domain])
|
414 |
+
if "/" not in str(file_list_path):
|
415 |
+
file_list_path = Path(opts.data.files.base) / Path(
|
416 |
+
opts.data.files[mode][domain]
|
417 |
+
)
|
418 |
+
|
419 |
+
if file_list_path.suffix == ".json":
|
420 |
+
self.samples_paths = self.json_load(file_list_path)
|
421 |
+
elif file_list_path.suffix in {".yaml", ".yml"}:
|
422 |
+
self.samples_paths = self.yaml_load(file_list_path)
|
423 |
+
else:
|
424 |
+
raise ValueError("Unknown file list type in {}".format(file_list_path))
|
425 |
+
|
426 |
+
if opts.data.max_samples and opts.data.max_samples != -1:
|
427 |
+
assert isinstance(opts.data.max_samples, int)
|
428 |
+
self.samples_paths = self.samples_paths[: opts.data.max_samples]
|
429 |
+
|
430 |
+
self.filter_samples()
|
431 |
+
if opts.data.check_samples:
|
432 |
+
print(f"Checking samples ({mode}, {domain})")
|
433 |
+
self.check_samples()
|
434 |
+
self.file_list_path = str(file_list_path)
|
435 |
+
self.transform = transform
|
436 |
+
|
437 |
+
def filter_samples(self):
|
438 |
+
"""
|
439 |
+
Filter out data which is not required for the model's tasks
|
440 |
+
as defined in opts.tasks
|
441 |
+
"""
|
442 |
+
self.samples_paths = [
|
443 |
+
{k: v for k, v in s.items() if k in self.tasks} for s in self.samples_paths
|
444 |
+
]
|
445 |
+
|
446 |
+
def __getitem__(self, i):
|
447 |
+
"""Return an item in the dataset with fields:
|
448 |
+
{
|
449 |
+
data: transform({
|
450 |
+
domains: values
|
451 |
+
}),
|
452 |
+
paths: [{task: path}],
|
453 |
+
domain: [domain],
|
454 |
+
mode: [train|val]
|
455 |
+
}
|
456 |
+
Args:
|
457 |
+
i (int): index of item to retrieve
|
458 |
+
Returns:
|
459 |
+
dict: dataset item where tensors of data are in item["data"] which is a dict
|
460 |
+
{task: tensor}
|
461 |
+
"""
|
462 |
+
paths = self.samples_paths[i]
|
463 |
+
|
464 |
+
# always apply transforms,
|
465 |
+
# if no transform is specified, ToTensor and Normalize will be applied
|
466 |
+
|
467 |
+
item = {
|
468 |
+
"data": self.transform(
|
469 |
+
{
|
470 |
+
task: tensor_loader(
|
471 |
+
env_to_path(path),
|
472 |
+
task,
|
473 |
+
self.domain,
|
474 |
+
self.opts,
|
475 |
+
)
|
476 |
+
for task, path in paths.items()
|
477 |
+
}
|
478 |
+
),
|
479 |
+
"paths": paths,
|
480 |
+
"domain": self.domain if self.domain != "kitti" else "s",
|
481 |
+
"mode": self.mode,
|
482 |
+
}
|
483 |
+
|
484 |
+
return item
|
485 |
+
|
486 |
+
def __len__(self):
|
487 |
+
return len(self.samples_paths)
|
488 |
+
|
489 |
+
def json_load(self, file_path):
|
490 |
+
with open(file_path, "r") as f:
|
491 |
+
return json.load(f)
|
492 |
+
|
493 |
+
def yaml_load(self, file_path):
|
494 |
+
with open(file_path, "r") as f:
|
495 |
+
return yaml.safe_load(f)
|
496 |
+
|
497 |
+
def check_samples(self):
|
498 |
+
"""Checks that every file listed in samples_paths actually
|
499 |
+
exist on the file-system
|
500 |
+
"""
|
501 |
+
for s in self.samples_paths:
|
502 |
+
for k, v in s.items():
|
503 |
+
assert Path(v).exists(), f"{k} {v} does not exist"
|
504 |
+
|
505 |
+
|
506 |
+
def get_loader(mode, domain, opts):
|
507 |
+
if (
|
508 |
+
domain != "kitti"
|
509 |
+
or not opts.train.kitti.pretrain
|
510 |
+
or not opts.train.kitti.batch_size
|
511 |
+
):
|
512 |
+
batch_size = opts.data.loaders.get("batch_size", 4)
|
513 |
+
else:
|
514 |
+
batch_size = opts.train.kitti.get("batch_size", 4)
|
515 |
+
|
516 |
+
return DataLoader(
|
517 |
+
OmniListDataset(
|
518 |
+
mode,
|
519 |
+
domain,
|
520 |
+
opts,
|
521 |
+
transform=transforms.Compose(get_transforms(opts, mode, domain)),
|
522 |
+
),
|
523 |
+
batch_size=batch_size,
|
524 |
+
shuffle=True,
|
525 |
+
num_workers=opts.data.loaders.get("num_workers", 8),
|
526 |
+
pin_memory=True, # faster transfer to gpu
|
527 |
+
drop_last=True, # avoids batchnorm pbs if last batch has size 1
|
528 |
+
)
|
529 |
+
|
530 |
+
|
531 |
+
def get_all_loaders(opts):
|
532 |
+
loaders = {}
|
533 |
+
for mode in ["train", "val"]:
|
534 |
+
loaders[mode] = {}
|
535 |
+
for domain in opts.domains:
|
536 |
+
if mode in opts.data.files:
|
537 |
+
if domain in opts.data.files[mode]:
|
538 |
+
loaders[mode][domain] = get_loader(mode, domain, opts)
|
539 |
+
return loaders
|
climategan/deeplab/__init__.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
|
6 |
+
from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
|
7 |
+
from climategan.deeplab.mobilenet_v3 import MobileNetV2
|
8 |
+
from climategan.deeplab.resnet101_v3 import ResNet101
|
9 |
+
from climategan.deeplab.resnetmulti_v2 import ResNetMulti
|
10 |
+
|
11 |
+
|
12 |
+
def create_encoder(opts, no_init=False, verbose=0):
|
13 |
+
if opts.gen.encoder.architecture == "deeplabv2":
|
14 |
+
if verbose > 0:
|
15 |
+
print(" - Add Deeplabv2 Encoder")
|
16 |
+
return DeeplabV2Encoder(opts, no_init, verbose)
|
17 |
+
elif opts.gen.encoder.architecture == "deeplabv3":
|
18 |
+
if verbose > 0:
|
19 |
+
backone = opts.gen.deeplabv3.backbone
|
20 |
+
print(" - Add Deeplabv3 ({}) Encoder".format(backone))
|
21 |
+
return build_v3_backbone(opts, no_init)
|
22 |
+
else:
|
23 |
+
raise NotImplementedError(
|
24 |
+
"Unknown encoder: {}".format(opts.gen.encoder.architecture)
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
def create_segmentation_decoder(opts, no_init=False, verbose=0):
|
29 |
+
if opts.gen.s.architecture == "deeplabv2":
|
30 |
+
if verbose > 0:
|
31 |
+
print(" - Add DeepLabV2Decoder")
|
32 |
+
return DeepLabV2Decoder(opts)
|
33 |
+
elif opts.gen.s.architecture == "deeplabv3":
|
34 |
+
if verbose > 0:
|
35 |
+
print(" - Add DeepLabV3Decoder")
|
36 |
+
return DeepLabV3Decoder(opts, no_init)
|
37 |
+
else:
|
38 |
+
raise NotImplementedError(
|
39 |
+
"Unknown Segmentation architecture: {}".format(opts.gen.s.architecture)
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
def build_v3_backbone(opts, no_init, verbose=0):
|
44 |
+
backbone = opts.gen.deeplabv3.backbone
|
45 |
+
output_stride = opts.gen.deeplabv3.output_stride
|
46 |
+
if backbone == "resnet":
|
47 |
+
resnet = ResNet101(
|
48 |
+
output_stride=output_stride,
|
49 |
+
BatchNorm=nn.BatchNorm2d,
|
50 |
+
verbose=verbose,
|
51 |
+
no_init=no_init,
|
52 |
+
)
|
53 |
+
if not no_init:
|
54 |
+
if opts.gen.deeplabv3.backbone == "resnet":
|
55 |
+
assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()
|
56 |
+
|
57 |
+
std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
|
58 |
+
resnet.load_state_dict(
|
59 |
+
{
|
60 |
+
k.replace("backbone.", ""): v
|
61 |
+
for k, v in std.items()
|
62 |
+
if k.startswith("backbone.")
|
63 |
+
}
|
64 |
+
)
|
65 |
+
print(
|
66 |
+
" - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
|
67 |
+
)
|
68 |
+
return resnet
|
69 |
+
|
70 |
+
elif opts.gen.deeplabv3.backbone == "mobilenet":
|
71 |
+
assert Path(opts.gen.deeplabv3.pretrained_model.mobilenet).exists()
|
72 |
+
mobilenet = MobileNetV2(
|
73 |
+
no_init=no_init,
|
74 |
+
pretrained_path=opts.gen.deeplabv3.pretrained_model.mobilenet,
|
75 |
+
)
|
76 |
+
print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
|
77 |
+
return mobilenet
|
78 |
+
|
79 |
+
else:
|
80 |
+
raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
|
81 |
+
|
82 |
+
|
83 |
+
class DeeplabV2Encoder(nn.Module):
|
84 |
+
def __init__(self, opts, no_init=False, verbose=0):
|
85 |
+
"""Deeplab architecture encoder"""
|
86 |
+
super().__init__()
|
87 |
+
|
88 |
+
self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
|
89 |
+
if opts.gen.deeplabv2.use_pretrained and not no_init:
|
90 |
+
saved_state_dict = torch.load(opts.gen.deeplabv2.pretrained_model)
|
91 |
+
new_params = self.model.state_dict().copy()
|
92 |
+
for i in saved_state_dict:
|
93 |
+
i_parts = i.split(".")
|
94 |
+
if not i_parts[1] in ["layer5", "resblock"]:
|
95 |
+
new_params[".".join(i_parts[1:])] = saved_state_dict[i]
|
96 |
+
self.model.load_state_dict(new_params)
|
97 |
+
if verbose > 0:
|
98 |
+
print(" - Loaded pretrained weights")
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
return self.model(x)
|
climategan/deeplab/deeplab_v2.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from climategan.blocks import InterpolateNearest2d
|
5 |
+
from climategan.utils import find_target_size
|
6 |
+
|
7 |
+
|
8 |
+
class _ASPPModule(nn.Module):
|
9 |
+
# https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
|
10 |
+
def __init__(
|
11 |
+
self, inplanes, planes, kernel_size, padding, dilation, BatchNorm, no_init
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
self.atrous_conv = nn.Conv2d(
|
15 |
+
inplanes,
|
16 |
+
planes,
|
17 |
+
kernel_size=kernel_size,
|
18 |
+
stride=1,
|
19 |
+
padding=padding,
|
20 |
+
dilation=dilation,
|
21 |
+
bias=False,
|
22 |
+
)
|
23 |
+
self.bn = BatchNorm(planes)
|
24 |
+
self.relu = nn.ReLU()
|
25 |
+
if not no_init:
|
26 |
+
self._init_weight()
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.atrous_conv(x)
|
30 |
+
x = self.bn(x)
|
31 |
+
|
32 |
+
return self.relu(x)
|
33 |
+
|
34 |
+
def _init_weight(self):
|
35 |
+
for m in self.modules():
|
36 |
+
if isinstance(m, nn.Conv2d):
|
37 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
38 |
+
elif isinstance(m, nn.BatchNorm2d):
|
39 |
+
m.weight.data.fill_(1)
|
40 |
+
m.bias.data.zero_()
|
41 |
+
|
42 |
+
|
43 |
+
class ASPP(nn.Module):
|
44 |
+
# https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
|
45 |
+
def __init__(self, backbone, output_stride, BatchNorm, no_init):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
if backbone == "mobilenet":
|
49 |
+
inplanes = 320
|
50 |
+
else:
|
51 |
+
inplanes = 2048
|
52 |
+
|
53 |
+
if output_stride == 16:
|
54 |
+
dilations = [1, 6, 12, 18]
|
55 |
+
elif output_stride == 8:
|
56 |
+
dilations = [1, 12, 24, 36]
|
57 |
+
else:
|
58 |
+
raise NotImplementedError
|
59 |
+
|
60 |
+
self.aspp1 = _ASPPModule(
|
61 |
+
inplanes,
|
62 |
+
256,
|
63 |
+
1,
|
64 |
+
padding=0,
|
65 |
+
dilation=dilations[0],
|
66 |
+
BatchNorm=BatchNorm,
|
67 |
+
no_init=no_init,
|
68 |
+
)
|
69 |
+
self.aspp2 = _ASPPModule(
|
70 |
+
inplanes,
|
71 |
+
256,
|
72 |
+
3,
|
73 |
+
padding=dilations[1],
|
74 |
+
dilation=dilations[1],
|
75 |
+
BatchNorm=BatchNorm,
|
76 |
+
no_init=no_init,
|
77 |
+
)
|
78 |
+
self.aspp3 = _ASPPModule(
|
79 |
+
inplanes,
|
80 |
+
256,
|
81 |
+
3,
|
82 |
+
padding=dilations[2],
|
83 |
+
dilation=dilations[2],
|
84 |
+
BatchNorm=BatchNorm,
|
85 |
+
no_init=no_init,
|
86 |
+
)
|
87 |
+
self.aspp4 = _ASPPModule(
|
88 |
+
inplanes,
|
89 |
+
256,
|
90 |
+
3,
|
91 |
+
padding=dilations[3],
|
92 |
+
dilation=dilations[3],
|
93 |
+
BatchNorm=BatchNorm,
|
94 |
+
no_init=no_init,
|
95 |
+
)
|
96 |
+
|
97 |
+
self.global_avg_pool = nn.Sequential(
|
98 |
+
nn.AdaptiveAvgPool2d((1, 1)),
|
99 |
+
nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
|
100 |
+
BatchNorm(256),
|
101 |
+
nn.ReLU(),
|
102 |
+
)
|
103 |
+
self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
|
104 |
+
self.bn1 = BatchNorm(256)
|
105 |
+
self.relu = nn.ReLU()
|
106 |
+
self.dropout = nn.Dropout(0.5)
|
107 |
+
if not no_init:
|
108 |
+
self._init_weight()
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
x1 = self.aspp1(x)
|
112 |
+
x2 = self.aspp2(x)
|
113 |
+
x3 = self.aspp3(x)
|
114 |
+
x4 = self.aspp4(x)
|
115 |
+
x5 = self.global_avg_pool(x)
|
116 |
+
x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True)
|
117 |
+
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
|
118 |
+
|
119 |
+
x = self.conv1(x)
|
120 |
+
x = self.bn1(x)
|
121 |
+
x = self.relu(x)
|
122 |
+
|
123 |
+
return self.dropout(x)
|
124 |
+
|
125 |
+
def _init_weight(self):
|
126 |
+
for m in self.modules():
|
127 |
+
if isinstance(m, nn.Conv2d):
|
128 |
+
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
129 |
+
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
130 |
+
torch.nn.init.kaiming_normal_(m.weight)
|
131 |
+
elif isinstance(m, nn.BatchNorm2d):
|
132 |
+
m.weight.data.fill_(1)
|
133 |
+
m.bias.data.zero_()
|
134 |
+
|
135 |
+
|
136 |
+
class DeepLabV2Decoder(nn.Module):
|
137 |
+
# https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/decoder.py
|
138 |
+
# https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
|
139 |
+
def __init__(self, opts, no_init=False):
|
140 |
+
super().__init__()
|
141 |
+
self.aspp = ASPP("resnet", 16, nn.BatchNorm2d, no_init)
|
142 |
+
self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
|
143 |
+
|
144 |
+
conv_modules = [
|
145 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
|
146 |
+
nn.BatchNorm2d(256),
|
147 |
+
nn.ReLU(),
|
148 |
+
nn.Dropout(0.5),
|
149 |
+
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
|
150 |
+
nn.BatchNorm2d(256),
|
151 |
+
nn.ReLU(),
|
152 |
+
nn.Dropout(0.1),
|
153 |
+
]
|
154 |
+
if opts.gen.s.upsample_featuremaps:
|
155 |
+
conv_modules = [InterpolateNearest2d(scale_factor=2)] + conv_modules
|
156 |
+
|
157 |
+
conv_modules += [
|
158 |
+
nn.Conv2d(256, opts.gen.s.output_dim, kernel_size=1, stride=1),
|
159 |
+
]
|
160 |
+
self.conv = nn.Sequential(*conv_modules)
|
161 |
+
|
162 |
+
self._target_size = find_target_size(opts, "s")
|
163 |
+
print(
|
164 |
+
" - {}: setting target size to {}".format(
|
165 |
+
self.__class__.__name__, self._target_size
|
166 |
+
)
|
167 |
+
)
|
168 |
+
|
169 |
+
def set_target_size(self, size):
|
170 |
+
"""
|
171 |
+
Set final interpolation's target size
|
172 |
+
|
173 |
+
Args:
|
174 |
+
size (int, list, tuple): target size (h, w). If int, target will be (i, i)
|
175 |
+
"""
|
176 |
+
if isinstance(size, (list, tuple)):
|
177 |
+
self._target_size = size[:2]
|
178 |
+
else:
|
179 |
+
self._target_size = (size, size)
|
180 |
+
|
181 |
+
def forward(self, z, z_depth=None):
|
182 |
+
if self._target_size is None:
|
183 |
+
error = "self._target_size should be set with self.set_target_size()"
|
184 |
+
error += "to interpolate logits to the target seg map's size"
|
185 |
+
raise Exception(error)
|
186 |
+
if isinstance(z, (list, tuple)):
|
187 |
+
z = z[0]
|
188 |
+
if z.shape[1] != 2048:
|
189 |
+
raise Exception(
|
190 |
+
"Segmentation decoder will only work with 2048 channels for z"
|
191 |
+
)
|
192 |
+
|
193 |
+
if z_depth is not None and self.use_dada:
|
194 |
+
z = z * z_depth
|
195 |
+
|
196 |
+
y = self.aspp(z)
|
197 |
+
y = self.conv(y)
|
198 |
+
return F.interpolate(y, self._target_size, mode="bilinear", align_corners=True)
|
climategan/deeplab/deeplab_v3.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py
|
3 |
+
"""
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from climategan.deeplab.mobilenet_v3 import SeparableConv2d
|
10 |
+
from climategan.utils import find_target_size
|
11 |
+
|
12 |
+
|
13 |
+
class _DeepLabHead(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
last_channels = c4_channels
|
19 |
+
# self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer)
|
20 |
+
# last_channels += 48
|
21 |
+
self.block = nn.Sequential(
|
22 |
+
SeparableConv2d(
|
23 |
+
last_channels, 256, 3, norm_layer=norm_layer, relu_first=False
|
24 |
+
),
|
25 |
+
SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False),
|
26 |
+
nn.Conv2d(256, nclass, 1),
|
27 |
+
)
|
28 |
+
|
29 |
+
def forward(self, x, c1=None):
|
30 |
+
return self.block(x)
|
31 |
+
|
32 |
+
|
33 |
+
class ConvBNReLU(nn.Module):
|
34 |
+
"""
|
35 |
+
https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.conv = nn.Conv2d(
|
43 |
+
in_chan,
|
44 |
+
out_chan,
|
45 |
+
kernel_size=ks,
|
46 |
+
stride=stride,
|
47 |
+
padding=padding,
|
48 |
+
dilation=dilation,
|
49 |
+
bias=True,
|
50 |
+
)
|
51 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
52 |
+
self.init_weight()
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
x = self.conv(x)
|
56 |
+
x = self.bn(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
def init_weight(self):
|
60 |
+
for ly in self.children():
|
61 |
+
if isinstance(ly, nn.Conv2d):
|
62 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
63 |
+
if ly.bias is not None:
|
64 |
+
nn.init.constant_(ly.bias, 0)
|
65 |
+
|
66 |
+
|
67 |
+
class ASPPv3Plus(nn.Module):
|
68 |
+
"""
|
69 |
+
https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, backbone, no_init):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
if backbone == "mobilenet":
|
76 |
+
in_chan = 320
|
77 |
+
else:
|
78 |
+
in_chan = 2048
|
79 |
+
|
80 |
+
self.with_gp = False
|
81 |
+
self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
|
82 |
+
self.conv2 = ConvBNReLU(in_chan, 256, ks=3, dilation=6, padding=6)
|
83 |
+
self.conv3 = ConvBNReLU(in_chan, 256, ks=3, dilation=12, padding=12)
|
84 |
+
self.conv4 = ConvBNReLU(in_chan, 256, ks=3, dilation=18, padding=18)
|
85 |
+
if self.with_gp:
|
86 |
+
self.avg = nn.AdaptiveAvgPool2d((1, 1))
|
87 |
+
self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
|
88 |
+
self.conv_out = ConvBNReLU(256 * 5, 256, ks=1)
|
89 |
+
else:
|
90 |
+
self.conv_out = ConvBNReLU(256 * 4, 256, ks=1)
|
91 |
+
|
92 |
+
if not no_init:
|
93 |
+
self.init_weight()
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
H, W = x.size()[2:]
|
97 |
+
feat1 = self.conv1(x)
|
98 |
+
feat2 = self.conv2(x)
|
99 |
+
feat3 = self.conv3(x)
|
100 |
+
feat4 = self.conv4(x)
|
101 |
+
if self.with_gp:
|
102 |
+
avg = self.avg(x)
|
103 |
+
feat5 = self.conv1x1(avg)
|
104 |
+
feat5 = F.interpolate(feat5, (H, W), mode="bilinear", align_corners=True)
|
105 |
+
feat = torch.cat([feat1, feat2, feat3, feat4, feat5], 1)
|
106 |
+
else:
|
107 |
+
feat = torch.cat([feat1, feat2, feat3, feat4], 1)
|
108 |
+
feat = self.conv_out(feat)
|
109 |
+
return feat
|
110 |
+
|
111 |
+
def init_weight(self):
|
112 |
+
for ly in self.children():
|
113 |
+
if isinstance(ly, nn.Conv2d):
|
114 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
115 |
+
if ly.bias is not None:
|
116 |
+
nn.init.constant_(ly.bias, 0)
|
117 |
+
|
118 |
+
|
119 |
+
class Decoder(nn.Module):
|
120 |
+
"""
|
121 |
+
https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, n_classes):
|
125 |
+
super(Decoder, self).__init__()
|
126 |
+
self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
|
127 |
+
self.conv_cat = nn.Sequential(
|
128 |
+
ConvBNReLU(304, 256, ks=3, padding=1),
|
129 |
+
ConvBNReLU(256, 256, ks=3, padding=1),
|
130 |
+
)
|
131 |
+
self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)
|
132 |
+
|
133 |
+
def forward(self, feat_low, feat_aspp):
|
134 |
+
H, W = feat_low.size()[2:]
|
135 |
+
feat_low = self.conv_low(feat_low)
|
136 |
+
feat_aspp_up = F.interpolate(
|
137 |
+
feat_aspp, (H, W), mode="bilinear", align_corners=True
|
138 |
+
)
|
139 |
+
feat_cat = torch.cat([feat_low, feat_aspp_up], dim=1)
|
140 |
+
feat_out = self.conv_cat(feat_cat)
|
141 |
+
logits = self.conv_out(feat_out)
|
142 |
+
return logits
|
143 |
+
|
144 |
+
|
145 |
+
"""
|
146 |
+
https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
|
147 |
+
"""
|
148 |
+
|
149 |
+
|
150 |
+
class DeepLabV3Decoder(nn.Module):
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
opts,
|
154 |
+
no_init=False,
|
155 |
+
freeze_bn=False,
|
156 |
+
):
|
157 |
+
super().__init__()
|
158 |
+
|
159 |
+
num_classes = opts.gen.s.output_dim
|
160 |
+
self.backbone = opts.gen.deeplabv3.backbone
|
161 |
+
self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
|
162 |
+
|
163 |
+
if self.backbone == "resnet":
|
164 |
+
self.aspp = ASPPv3Plus(self.backbone, no_init)
|
165 |
+
self.decoder = Decoder(num_classes)
|
166 |
+
|
167 |
+
self.freeze_bn = freeze_bn
|
168 |
+
else:
|
169 |
+
self.head = _DeepLabHead(num_classes, c4_channels=320)
|
170 |
+
|
171 |
+
self._target_size = find_target_size(opts, "s")
|
172 |
+
print(
|
173 |
+
" - {}: setting target size to {}".format(
|
174 |
+
self.__class__.__name__, self._target_size
|
175 |
+
)
|
176 |
+
)
|
177 |
+
|
178 |
+
if not no_init:
|
179 |
+
for m in self.modules():
|
180 |
+
if isinstance(m, nn.Conv2d):
|
181 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
182 |
+
if m.bias is not None:
|
183 |
+
nn.init.zeros_(m.bias)
|
184 |
+
elif isinstance(m, nn.BatchNorm2d):
|
185 |
+
nn.init.ones_(m.weight)
|
186 |
+
nn.init.zeros_(m.bias)
|
187 |
+
elif isinstance(m, nn.Linear):
|
188 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
189 |
+
nn.init.zeros_(m.bias)
|
190 |
+
|
191 |
+
self.load_pretrained(opts)
|
192 |
+
|
193 |
+
def load_pretrained(self, opts):
|
194 |
+
assert opts.gen.deeplabv3.backbone in {"resnet", "mobilenet"}
|
195 |
+
assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()
|
196 |
+
if opts.gen.deeplabv3.backbone == "resnet":
|
197 |
+
std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
|
198 |
+
self.aspp.load_state_dict(
|
199 |
+
{
|
200 |
+
k.replace("aspp.", ""): v
|
201 |
+
for k, v in std.items()
|
202 |
+
if k.startswith("aspp.")
|
203 |
+
}
|
204 |
+
)
|
205 |
+
self.decoder.load_state_dict(
|
206 |
+
{
|
207 |
+
k.replace("decoder.", ""): v
|
208 |
+
for k, v in std.items()
|
209 |
+
if k.startswith("decoder.")
|
210 |
+
and not (len(v.shape) > 0 and v.shape[0] == 19)
|
211 |
+
},
|
212 |
+
strict=False,
|
213 |
+
)
|
214 |
+
print(
|
215 |
+
"- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
|
216 |
+
)
|
217 |
+
else:
|
218 |
+
std = torch.load(opts.gen.deeplabv3.pretrained_model.mobilenet)
|
219 |
+
self.load_state_dict(
|
220 |
+
{
|
221 |
+
k: v
|
222 |
+
for k, v in std.items()
|
223 |
+
if k.startswith("head.")
|
224 |
+
and not (len(v.shape) > 0 and v.shape[0] == 19)
|
225 |
+
},
|
226 |
+
strict=False,
|
227 |
+
)
|
228 |
+
print(
|
229 |
+
" - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
|
230 |
+
)
|
231 |
+
|
232 |
+
def set_target_size(self, size):
|
233 |
+
"""
|
234 |
+
Set final interpolation's target size
|
235 |
+
|
236 |
+
Args:
|
237 |
+
size (int, list, tuple): target size (h, w). If int, target will be (i, i)
|
238 |
+
"""
|
239 |
+
if isinstance(size, (list, tuple)):
|
240 |
+
self._target_size = size[:2]
|
241 |
+
else:
|
242 |
+
self._target_size = (size, size)
|
243 |
+
|
244 |
+
def forward(self, z, z_depth=None):
|
245 |
+
assert isinstance(z, (tuple, list))
|
246 |
+
if self._target_size is None:
|
247 |
+
error = "self._target_size should be set with self.set_target_size()"
|
248 |
+
error += "to interpolate logits to the target seg map's size"
|
249 |
+
raise ValueError(error)
|
250 |
+
|
251 |
+
z_high, z_low = z
|
252 |
+
|
253 |
+
if z_depth is not None and self.use_dada:
|
254 |
+
z_high = z_high * z_depth
|
255 |
+
|
256 |
+
if self.backbone == "resnet":
|
257 |
+
z_high = self.aspp(z_high)
|
258 |
+
s = self.decoder(z_high, z_low)
|
259 |
+
else:
|
260 |
+
s = self.head(z_high)
|
261 |
+
|
262 |
+
s = F.interpolate(
|
263 |
+
s, size=self._target_size, mode="bilinear", align_corners=True
|
264 |
+
)
|
265 |
+
|
266 |
+
return s
|
267 |
+
|
268 |
+
def freeze_bn(self):
|
269 |
+
for m in self.modules():
|
270 |
+
if isinstance(m, nn.BatchNorm2d):
|
271 |
+
m.eval()
|
climategan/deeplab/mobilenet_v3.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
from https://github.com/LikeLy-Journey/SegmenTron/blob/
|
3 |
+
4bc605eedde7d680314f63d329277b73f83b1c5f/segmentron/modules/basic.py#L34
|
4 |
+
"""
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from climategan.blocks import InterpolateNearest2d
|
12 |
+
|
13 |
+
|
14 |
+
class SeparableConv2d(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
inplanes,
|
18 |
+
planes,
|
19 |
+
kernel_size=3,
|
20 |
+
stride=1,
|
21 |
+
dilation=1,
|
22 |
+
relu_first=True,
|
23 |
+
bias=False,
|
24 |
+
norm_layer=nn.BatchNorm2d,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
depthwise = nn.Conv2d(
|
28 |
+
inplanes,
|
29 |
+
inplanes,
|
30 |
+
kernel_size,
|
31 |
+
stride=stride,
|
32 |
+
padding=dilation,
|
33 |
+
dilation=dilation,
|
34 |
+
groups=inplanes,
|
35 |
+
bias=bias,
|
36 |
+
)
|
37 |
+
bn_depth = norm_layer(inplanes)
|
38 |
+
pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
|
39 |
+
bn_point = norm_layer(planes)
|
40 |
+
|
41 |
+
if relu_first:
|
42 |
+
self.block = nn.Sequential(
|
43 |
+
OrderedDict(
|
44 |
+
[
|
45 |
+
("relu", nn.ReLU()),
|
46 |
+
("depthwise", depthwise),
|
47 |
+
("bn_depth", bn_depth),
|
48 |
+
("pointwise", pointwise),
|
49 |
+
("bn_point", bn_point),
|
50 |
+
]
|
51 |
+
)
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
self.block = nn.Sequential(
|
55 |
+
OrderedDict(
|
56 |
+
[
|
57 |
+
("depthwise", depthwise),
|
58 |
+
("bn_depth", bn_depth),
|
59 |
+
("relu1", nn.ReLU(inplace=True)),
|
60 |
+
("pointwise", pointwise),
|
61 |
+
("bn_point", bn_point),
|
62 |
+
("relu2", nn.ReLU(inplace=True)),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
return self.block(x)
|
69 |
+
|
70 |
+
|
71 |
+
class _ConvBNReLU(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
in_channels,
|
75 |
+
out_channels,
|
76 |
+
kernel_size,
|
77 |
+
stride=1,
|
78 |
+
padding=0,
|
79 |
+
dilation=1,
|
80 |
+
groups=1,
|
81 |
+
relu6=False,
|
82 |
+
norm_layer=nn.BatchNorm2d,
|
83 |
+
):
|
84 |
+
super(_ConvBNReLU, self).__init__()
|
85 |
+
self.conv = nn.Conv2d(
|
86 |
+
in_channels,
|
87 |
+
out_channels,
|
88 |
+
kernel_size,
|
89 |
+
stride,
|
90 |
+
padding,
|
91 |
+
dilation,
|
92 |
+
groups,
|
93 |
+
bias=False,
|
94 |
+
)
|
95 |
+
self.bn = norm_layer(out_channels)
|
96 |
+
self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
|
97 |
+
|
98 |
+
def forward(self, x):
|
99 |
+
x = self.conv(x)
|
100 |
+
x = self.bn(x)
|
101 |
+
x = self.relu(x)
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class _DepthwiseConv(nn.Module):
|
106 |
+
"""conv_dw in MobileNet"""
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs
|
110 |
+
):
|
111 |
+
super(_DepthwiseConv, self).__init__()
|
112 |
+
self.conv = nn.Sequential(
|
113 |
+
_ConvBNReLU(
|
114 |
+
in_channels,
|
115 |
+
in_channels,
|
116 |
+
3,
|
117 |
+
stride,
|
118 |
+
1,
|
119 |
+
groups=in_channels,
|
120 |
+
norm_layer=norm_layer,
|
121 |
+
),
|
122 |
+
_ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer),
|
123 |
+
)
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
return self.conv(x)
|
127 |
+
|
128 |
+
|
129 |
+
class InvertedResidual(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
in_channels,
|
133 |
+
out_channels,
|
134 |
+
stride,
|
135 |
+
expand_ratio,
|
136 |
+
dilation=1,
|
137 |
+
norm_layer=nn.BatchNorm2d,
|
138 |
+
):
|
139 |
+
super(InvertedResidual, self).__init__()
|
140 |
+
assert stride in [1, 2]
|
141 |
+
self.use_res_connect = stride == 1 and in_channels == out_channels
|
142 |
+
|
143 |
+
layers = list()
|
144 |
+
inter_channels = int(round(in_channels * expand_ratio))
|
145 |
+
if expand_ratio != 1:
|
146 |
+
# pw
|
147 |
+
layers.append(
|
148 |
+
_ConvBNReLU(
|
149 |
+
in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer
|
150 |
+
)
|
151 |
+
)
|
152 |
+
layers.extend(
|
153 |
+
[
|
154 |
+
# dw
|
155 |
+
_ConvBNReLU(
|
156 |
+
inter_channels,
|
157 |
+
inter_channels,
|
158 |
+
3,
|
159 |
+
stride,
|
160 |
+
dilation,
|
161 |
+
dilation,
|
162 |
+
groups=inter_channels,
|
163 |
+
relu6=True,
|
164 |
+
norm_layer=norm_layer,
|
165 |
+
),
|
166 |
+
# pw-linear
|
167 |
+
nn.Conv2d(inter_channels, out_channels, 1, bias=False),
|
168 |
+
norm_layer(out_channels),
|
169 |
+
]
|
170 |
+
)
|
171 |
+
self.conv = nn.Sequential(*layers)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
if self.use_res_connect:
|
175 |
+
return x + self.conv(x)
|
176 |
+
else:
|
177 |
+
return self.conv(x)
|
178 |
+
|
179 |
+
|
180 |
+
class MobileNetV2(nn.Module):
|
181 |
+
def __init__(self, norm_layer=nn.BatchNorm2d, pretrained_path=None, no_init=False):
|
182 |
+
super(MobileNetV2, self).__init__()
|
183 |
+
output_stride = 16
|
184 |
+
self.multiplier = 1.0
|
185 |
+
if output_stride == 32:
|
186 |
+
dilations = [1, 1]
|
187 |
+
elif output_stride == 16:
|
188 |
+
dilations = [1, 2]
|
189 |
+
elif output_stride == 8:
|
190 |
+
dilations = [2, 4]
|
191 |
+
else:
|
192 |
+
raise NotImplementedError
|
193 |
+
inverted_residual_setting = [
|
194 |
+
# t, c, n, s
|
195 |
+
[1, 16, 1, 1],
|
196 |
+
[6, 24, 2, 2],
|
197 |
+
[6, 32, 3, 2],
|
198 |
+
[6, 64, 4, 2],
|
199 |
+
[6, 96, 3, 1],
|
200 |
+
[6, 160, 3, 2],
|
201 |
+
[6, 320, 1, 1],
|
202 |
+
]
|
203 |
+
# building first layer
|
204 |
+
input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
|
205 |
+
# last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
|
206 |
+
self.conv1 = _ConvBNReLU(
|
207 |
+
3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer
|
208 |
+
)
|
209 |
+
|
210 |
+
# building inverted residual blocks
|
211 |
+
self.planes = input_channels
|
212 |
+
self.block1 = self._make_layer(
|
213 |
+
InvertedResidual,
|
214 |
+
self.planes,
|
215 |
+
inverted_residual_setting[0:1],
|
216 |
+
norm_layer=norm_layer,
|
217 |
+
)
|
218 |
+
self.block2 = self._make_layer(
|
219 |
+
InvertedResidual,
|
220 |
+
self.planes,
|
221 |
+
inverted_residual_setting[1:2],
|
222 |
+
norm_layer=norm_layer,
|
223 |
+
)
|
224 |
+
self.block3 = self._make_layer(
|
225 |
+
InvertedResidual,
|
226 |
+
self.planes,
|
227 |
+
inverted_residual_setting[2:3],
|
228 |
+
norm_layer=norm_layer,
|
229 |
+
)
|
230 |
+
self.block4 = self._make_layer(
|
231 |
+
InvertedResidual,
|
232 |
+
self.planes,
|
233 |
+
inverted_residual_setting[3:5],
|
234 |
+
dilations[0],
|
235 |
+
norm_layer=norm_layer,
|
236 |
+
)
|
237 |
+
self.block5 = self._make_layer(
|
238 |
+
InvertedResidual,
|
239 |
+
self.planes,
|
240 |
+
inverted_residual_setting[5:],
|
241 |
+
dilations[1],
|
242 |
+
norm_layer=norm_layer,
|
243 |
+
)
|
244 |
+
self.last_inp_channels = self.planes
|
245 |
+
|
246 |
+
self.up2 = InterpolateNearest2d()
|
247 |
+
|
248 |
+
# weight initialization
|
249 |
+
if not no_init:
|
250 |
+
self.pretrained_path = pretrained_path
|
251 |
+
if pretrained_path is not None:
|
252 |
+
self._load_pretrained_model()
|
253 |
+
else:
|
254 |
+
for m in self.modules():
|
255 |
+
if isinstance(m, nn.Conv2d):
|
256 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
257 |
+
if m.bias is not None:
|
258 |
+
nn.init.zeros_(m.bias)
|
259 |
+
elif isinstance(m, nn.BatchNorm2d):
|
260 |
+
nn.init.ones_(m.weight)
|
261 |
+
nn.init.zeros_(m.bias)
|
262 |
+
elif isinstance(m, nn.Linear):
|
263 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
264 |
+
if m.bias is not None:
|
265 |
+
nn.init.zeros_(m.bias)
|
266 |
+
|
267 |
+
def _make_layer(
|
268 |
+
self,
|
269 |
+
block,
|
270 |
+
planes,
|
271 |
+
inverted_residual_setting,
|
272 |
+
dilation=1,
|
273 |
+
norm_layer=nn.BatchNorm2d,
|
274 |
+
):
|
275 |
+
features = list()
|
276 |
+
for t, c, n, s in inverted_residual_setting:
|
277 |
+
out_channels = int(c * self.multiplier)
|
278 |
+
stride = s if dilation == 1 else 1
|
279 |
+
features.append(
|
280 |
+
block(planes, out_channels, stride, t, dilation, norm_layer)
|
281 |
+
)
|
282 |
+
planes = out_channels
|
283 |
+
for i in range(n - 1):
|
284 |
+
features.append(
|
285 |
+
block(planes, out_channels, 1, t, norm_layer=norm_layer)
|
286 |
+
)
|
287 |
+
planes = out_channels
|
288 |
+
self.planes = planes
|
289 |
+
return nn.Sequential(*features)
|
290 |
+
|
291 |
+
def forward(self, x):
|
292 |
+
x = self.conv1(x)
|
293 |
+
x = self.block1(x)
|
294 |
+
c1 = self.block2(x)
|
295 |
+
c2 = self.block3(c1)
|
296 |
+
c3 = self.block4(c2)
|
297 |
+
c4 = self.up2(self.block5(c3))
|
298 |
+
|
299 |
+
# x = self.features(x)
|
300 |
+
# x = self.classifier(x.view(x.size(0), x.size(1)))
|
301 |
+
return c4, c1
|
302 |
+
|
303 |
+
def _load_pretrained_model(self):
|
304 |
+
assert self.pretrained_path is not None
|
305 |
+
assert Path(self.pretrained_path).exists()
|
306 |
+
|
307 |
+
pretrain_dict = torch.load(self.pretrained_path)
|
308 |
+
pretrain_dict = {k.replace("encoder.", ""): v for k, v in pretrain_dict.items()}
|
309 |
+
model_dict = {}
|
310 |
+
state_dict = self.state_dict()
|
311 |
+
ignored = []
|
312 |
+
for k, v in pretrain_dict.items():
|
313 |
+
if k in state_dict:
|
314 |
+
model_dict[k] = v
|
315 |
+
else:
|
316 |
+
ignored.append(k)
|
317 |
+
state_dict.update(model_dict)
|
318 |
+
self.load_state_dict(state_dict)
|
319 |
+
self.loaded_pre_trained = True
|
320 |
+
print(
|
321 |
+
" - Loaded pre-trained MobileNetV2: ignored {}/{} keys".format(
|
322 |
+
len(ignored), len(pretrain_dict)
|
323 |
+
)
|
324 |
+
)
|
climategan/deeplab/resnet101_v3.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
class Bottleneck(nn.Module):
|
5 |
+
expansion = 4
|
6 |
+
|
7 |
+
def __init__(
|
8 |
+
self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None
|
9 |
+
):
|
10 |
+
super(Bottleneck, self).__init__()
|
11 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
12 |
+
self.bn1 = BatchNorm(planes)
|
13 |
+
self.conv2 = nn.Conv2d(
|
14 |
+
planes,
|
15 |
+
planes,
|
16 |
+
kernel_size=3,
|
17 |
+
stride=stride,
|
18 |
+
dilation=dilation,
|
19 |
+
padding=dilation,
|
20 |
+
bias=False,
|
21 |
+
)
|
22 |
+
self.bn2 = BatchNorm(planes)
|
23 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
24 |
+
self.bn3 = BatchNorm(planes * 4)
|
25 |
+
self.relu = nn.ReLU(inplace=True)
|
26 |
+
self.downsample = downsample
|
27 |
+
self.stride = stride
|
28 |
+
self.dilation = dilation
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
residual = x
|
32 |
+
|
33 |
+
out = self.conv1(x)
|
34 |
+
out = self.bn1(out)
|
35 |
+
out = self.relu(out)
|
36 |
+
|
37 |
+
out = self.conv2(out)
|
38 |
+
out = self.bn2(out)
|
39 |
+
out = self.relu(out)
|
40 |
+
|
41 |
+
out = self.conv3(out)
|
42 |
+
out = self.bn3(out)
|
43 |
+
|
44 |
+
if self.downsample is not None:
|
45 |
+
residual = self.downsample(x)
|
46 |
+
|
47 |
+
out += residual
|
48 |
+
out = self.relu(out)
|
49 |
+
|
50 |
+
return out
|
51 |
+
|
52 |
+
|
53 |
+
class ResNet(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self, block, layers, output_stride, BatchNorm, verbose=0, no_init=False
|
56 |
+
):
|
57 |
+
self.inplanes = 64
|
58 |
+
self.verbose = verbose
|
59 |
+
super(ResNet, self).__init__()
|
60 |
+
blocks = [1, 2, 4]
|
61 |
+
if output_stride == 16:
|
62 |
+
strides = [1, 2, 2, 1]
|
63 |
+
dilations = [1, 1, 1, 2]
|
64 |
+
elif output_stride == 8:
|
65 |
+
strides = [1, 2, 1, 1]
|
66 |
+
dilations = [1, 1, 2, 4]
|
67 |
+
else:
|
68 |
+
raise NotImplementedError
|
69 |
+
|
70 |
+
# Modules
|
71 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
72 |
+
self.bn1 = BatchNorm(64)
|
73 |
+
self.relu = nn.ReLU(inplace=True)
|
74 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
75 |
+
|
76 |
+
self.layer1 = self._make_layer(
|
77 |
+
block,
|
78 |
+
64,
|
79 |
+
layers[0],
|
80 |
+
stride=strides[0],
|
81 |
+
dilation=dilations[0],
|
82 |
+
BatchNorm=BatchNorm,
|
83 |
+
)
|
84 |
+
self.layer2 = self._make_layer(
|
85 |
+
block,
|
86 |
+
128,
|
87 |
+
layers[1],
|
88 |
+
stride=strides[1],
|
89 |
+
dilation=dilations[1],
|
90 |
+
BatchNorm=BatchNorm,
|
91 |
+
)
|
92 |
+
self.layer3 = self._make_layer(
|
93 |
+
block,
|
94 |
+
256,
|
95 |
+
layers[2],
|
96 |
+
stride=strides[2],
|
97 |
+
dilation=dilations[2],
|
98 |
+
BatchNorm=BatchNorm,
|
99 |
+
)
|
100 |
+
self.layer4 = self._make_MG_unit(
|
101 |
+
block,
|
102 |
+
512,
|
103 |
+
blocks=blocks,
|
104 |
+
stride=strides[3],
|
105 |
+
dilation=dilations[3],
|
106 |
+
BatchNorm=BatchNorm,
|
107 |
+
)
|
108 |
+
|
109 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
|
110 |
+
downsample = None
|
111 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
112 |
+
downsample = nn.Sequential(
|
113 |
+
nn.Conv2d(
|
114 |
+
self.inplanes,
|
115 |
+
planes * block.expansion,
|
116 |
+
kernel_size=1,
|
117 |
+
stride=stride,
|
118 |
+
bias=False,
|
119 |
+
),
|
120 |
+
BatchNorm(planes * block.expansion),
|
121 |
+
)
|
122 |
+
|
123 |
+
layers = []
|
124 |
+
layers.append(
|
125 |
+
block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)
|
126 |
+
)
|
127 |
+
self.inplanes = planes * block.expansion
|
128 |
+
for i in range(1, blocks):
|
129 |
+
layers.append(
|
130 |
+
block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)
|
131 |
+
)
|
132 |
+
|
133 |
+
return nn.Sequential(*layers)
|
134 |
+
|
135 |
+
def _make_MG_unit(
|
136 |
+
self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None
|
137 |
+
):
|
138 |
+
downsample = None
|
139 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
140 |
+
downsample = nn.Sequential(
|
141 |
+
nn.Conv2d(
|
142 |
+
self.inplanes,
|
143 |
+
planes * block.expansion,
|
144 |
+
kernel_size=1,
|
145 |
+
stride=stride,
|
146 |
+
bias=False,
|
147 |
+
),
|
148 |
+
BatchNorm(planes * block.expansion),
|
149 |
+
)
|
150 |
+
|
151 |
+
layers = []
|
152 |
+
layers.append(
|
153 |
+
block(
|
154 |
+
self.inplanes,
|
155 |
+
planes,
|
156 |
+
stride,
|
157 |
+
dilation=blocks[0] * dilation,
|
158 |
+
downsample=downsample,
|
159 |
+
BatchNorm=BatchNorm,
|
160 |
+
)
|
161 |
+
)
|
162 |
+
self.inplanes = planes * block.expansion
|
163 |
+
for i in range(1, len(blocks)):
|
164 |
+
layers.append(
|
165 |
+
block(
|
166 |
+
self.inplanes,
|
167 |
+
planes,
|
168 |
+
stride=1,
|
169 |
+
dilation=blocks[i] * dilation,
|
170 |
+
BatchNorm=BatchNorm,
|
171 |
+
)
|
172 |
+
)
|
173 |
+
|
174 |
+
return nn.Sequential(*layers)
|
175 |
+
|
176 |
+
def forward(self, input):
|
177 |
+
x = self.conv1(input)
|
178 |
+
x = self.bn1(x)
|
179 |
+
x = self.relu(x)
|
180 |
+
x = self.maxpool(x)
|
181 |
+
|
182 |
+
x = self.layer1(x)
|
183 |
+
low_level_feat = x
|
184 |
+
x = self.layer2(x)
|
185 |
+
x = self.layer3(x)
|
186 |
+
x = self.layer4(x)
|
187 |
+
return x, low_level_feat
|
188 |
+
|
189 |
+
|
190 |
+
def ResNet101(output_stride=8, BatchNorm=nn.BatchNorm2d, verbose=0, no_init=False):
|
191 |
+
"""Constructs a ResNet-101 model.
|
192 |
+
Args:
|
193 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
194 |
+
"""
|
195 |
+
model = ResNet(
|
196 |
+
Bottleneck,
|
197 |
+
[3, 4, 23, 3],
|
198 |
+
output_stride,
|
199 |
+
BatchNorm,
|
200 |
+
verbose=verbose,
|
201 |
+
no_init=no_init,
|
202 |
+
)
|
203 |
+
return model
|
climategan/deeplab/resnetmulti_v2.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from climategan.blocks import ResBlocks
|
3 |
+
|
4 |
+
affine_par = True
|
5 |
+
|
6 |
+
|
7 |
+
class Bottleneck(nn.Module):
|
8 |
+
expansion = 4
|
9 |
+
|
10 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
|
11 |
+
super(Bottleneck, self).__init__()
|
12 |
+
# change
|
13 |
+
self.conv1 = nn.Conv2d(
|
14 |
+
inplanes, planes, kernel_size=1, stride=stride, bias=False
|
15 |
+
)
|
16 |
+
self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
|
17 |
+
for i in self.bn1.parameters():
|
18 |
+
i.requires_grad = False
|
19 |
+
padding = dilation
|
20 |
+
# change
|
21 |
+
self.conv2 = nn.Conv2d(
|
22 |
+
planes,
|
23 |
+
planes,
|
24 |
+
kernel_size=3,
|
25 |
+
stride=1,
|
26 |
+
padding=padding,
|
27 |
+
bias=False,
|
28 |
+
dilation=dilation,
|
29 |
+
)
|
30 |
+
self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
|
31 |
+
for i in self.bn2.parameters():
|
32 |
+
i.requires_grad = False
|
33 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
34 |
+
self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
|
35 |
+
for i in self.bn3.parameters():
|
36 |
+
i.requires_grad = False
|
37 |
+
self.relu = nn.ReLU(inplace=True)
|
38 |
+
self.downsample = downsample
|
39 |
+
self.stride = stride
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
residual = x
|
43 |
+
out = self.conv1(x)
|
44 |
+
out = self.bn1(out)
|
45 |
+
out = self.relu(out)
|
46 |
+
out = self.conv2(out)
|
47 |
+
out = self.bn2(out)
|
48 |
+
out = self.relu(out)
|
49 |
+
out = self.conv3(out)
|
50 |
+
out = self.bn3(out)
|
51 |
+
if self.downsample is not None:
|
52 |
+
residual = self.downsample(x)
|
53 |
+
out += residual
|
54 |
+
out = self.relu(out)
|
55 |
+
|
56 |
+
return out
|
57 |
+
|
58 |
+
|
59 |
+
class ResNetMulti(nn.Module):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
layers,
|
63 |
+
n_res=4,
|
64 |
+
res_norm="instance",
|
65 |
+
activ="lrelu",
|
66 |
+
pad_type="reflect",
|
67 |
+
):
|
68 |
+
self.inplanes = 64
|
69 |
+
block = Bottleneck
|
70 |
+
super(ResNetMulti, self).__init__()
|
71 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
72 |
+
self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
|
73 |
+
for i in self.bn1.parameters():
|
74 |
+
i.requires_grad = False
|
75 |
+
self.relu = nn.ReLU(inplace=True)
|
76 |
+
self.maxpool = nn.MaxPool2d(
|
77 |
+
kernel_size=3, stride=2, padding=0, ceil_mode=True
|
78 |
+
) # changed padding from 1 to 0
|
79 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
80 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
81 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
|
82 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
|
83 |
+
|
84 |
+
for m in self.modules():
|
85 |
+
if isinstance(m, nn.Conv2d):
|
86 |
+
m.weight.data.normal_(0, 0.01)
|
87 |
+
elif isinstance(m, nn.BatchNorm2d):
|
88 |
+
m.weight.data.fill_(1)
|
89 |
+
m.bias.data.zero_()
|
90 |
+
self.layer_res = ResBlocks(
|
91 |
+
n_res, 2048, norm=res_norm, activation=activ, pad_type=pad_type
|
92 |
+
)
|
93 |
+
|
94 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
|
95 |
+
downsample = None
|
96 |
+
if (
|
97 |
+
stride != 1
|
98 |
+
or self.inplanes != planes * block.expansion
|
99 |
+
or dilation == 2
|
100 |
+
or dilation == 4
|
101 |
+
):
|
102 |
+
downsample = nn.Sequential(
|
103 |
+
nn.Conv2d(
|
104 |
+
self.inplanes,
|
105 |
+
planes * block.expansion,
|
106 |
+
kernel_size=1,
|
107 |
+
stride=stride,
|
108 |
+
bias=False,
|
109 |
+
),
|
110 |
+
nn.BatchNorm2d(planes * block.expansion, affine=affine_par),
|
111 |
+
)
|
112 |
+
for i in downsample._modules["1"].parameters():
|
113 |
+
i.requires_grad = False
|
114 |
+
layers = []
|
115 |
+
layers.append(
|
116 |
+
block(
|
117 |
+
self.inplanes, planes, stride, dilation=dilation, downsample=downsample
|
118 |
+
)
|
119 |
+
)
|
120 |
+
self.inplanes = planes * block.expansion
|
121 |
+
for i in range(1, blocks):
|
122 |
+
layers.append(block(self.inplanes, planes, dilation=dilation))
|
123 |
+
|
124 |
+
return nn.Sequential(*layers)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
x = self.conv1(x)
|
128 |
+
x = self.bn1(x)
|
129 |
+
x = self.relu(x)
|
130 |
+
x = self.maxpool(x)
|
131 |
+
x = self.layer1(x)
|
132 |
+
x = self.layer2(x)
|
133 |
+
x = self.layer3(x)
|
134 |
+
x = self.layer4(x)
|
135 |
+
x = self.layer_res(x)
|
136 |
+
return x
|
climategan/depth.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
|
6 |
+
from climategan.utils import find_target_size
|
7 |
+
|
8 |
+
|
9 |
+
def create_depth_decoder(opts, no_init=False, verbose=0):
|
10 |
+
if opts.gen.d.architecture == "base":
|
11 |
+
decoder = BaseDepthDecoder(opts)
|
12 |
+
if "s" in opts.task:
|
13 |
+
assert opts.gen.s.use_dada is False
|
14 |
+
if "m" in opts.tasks:
|
15 |
+
assert opts.gen.m.use_dada is False
|
16 |
+
else:
|
17 |
+
decoder = DADADepthDecoder(opts)
|
18 |
+
|
19 |
+
if verbose > 0:
|
20 |
+
print(f" - Add {decoder.__class__.__name__}")
|
21 |
+
|
22 |
+
return decoder
|
23 |
+
|
24 |
+
|
25 |
+
class DADADepthDecoder(nn.Module):
|
26 |
+
"""
|
27 |
+
Depth decoder based on depth auxiliary task in DADA paper
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, opts):
|
31 |
+
super().__init__()
|
32 |
+
if (
|
33 |
+
opts.gen.encoder.architecture == "deeplabv3"
|
34 |
+
and opts.gen.deeplabv3.backbone == "mobilenet"
|
35 |
+
):
|
36 |
+
res_dim = 320
|
37 |
+
else:
|
38 |
+
res_dim = 2048
|
39 |
+
|
40 |
+
mid_dim = 512
|
41 |
+
|
42 |
+
self.do_feat_fusion = False
|
43 |
+
if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
|
44 |
+
self.do_feat_fusion = True
|
45 |
+
self.dec4 = Conv2dBlock(
|
46 |
+
128,
|
47 |
+
res_dim,
|
48 |
+
1,
|
49 |
+
stride=1,
|
50 |
+
padding=0,
|
51 |
+
bias=True,
|
52 |
+
activation="lrelu",
|
53 |
+
norm="none",
|
54 |
+
)
|
55 |
+
|
56 |
+
self.relu = nn.ReLU(inplace=True)
|
57 |
+
self.enc4_1 = Conv2dBlock(
|
58 |
+
res_dim,
|
59 |
+
mid_dim,
|
60 |
+
1,
|
61 |
+
stride=1,
|
62 |
+
padding=0,
|
63 |
+
bias=False,
|
64 |
+
activation="lrelu",
|
65 |
+
pad_type="reflect",
|
66 |
+
norm="batch",
|
67 |
+
)
|
68 |
+
self.enc4_2 = Conv2dBlock(
|
69 |
+
mid_dim,
|
70 |
+
mid_dim,
|
71 |
+
3,
|
72 |
+
stride=1,
|
73 |
+
padding=1,
|
74 |
+
bias=False,
|
75 |
+
activation="lrelu",
|
76 |
+
pad_type="reflect",
|
77 |
+
norm="batch",
|
78 |
+
)
|
79 |
+
self.enc4_3 = Conv2dBlock(
|
80 |
+
mid_dim,
|
81 |
+
128,
|
82 |
+
1,
|
83 |
+
stride=1,
|
84 |
+
padding=0,
|
85 |
+
bias=False,
|
86 |
+
activation="lrelu",
|
87 |
+
pad_type="reflect",
|
88 |
+
norm="batch",
|
89 |
+
)
|
90 |
+
self.upsample = None
|
91 |
+
if opts.gen.d.upsample_featuremaps:
|
92 |
+
self.upsample = nn.Sequential(
|
93 |
+
*[
|
94 |
+
InterpolateNearest2d(),
|
95 |
+
Conv2dBlock(
|
96 |
+
128,
|
97 |
+
32,
|
98 |
+
3,
|
99 |
+
stride=1,
|
100 |
+
padding=1,
|
101 |
+
bias=False,
|
102 |
+
activation="lrelu",
|
103 |
+
pad_type="reflect",
|
104 |
+
norm="batch",
|
105 |
+
),
|
106 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
107 |
+
]
|
108 |
+
)
|
109 |
+
self._target_size = find_target_size(opts, "d")
|
110 |
+
print(
|
111 |
+
" - {}: setting target size to {}".format(
|
112 |
+
self.__class__.__name__, self._target_size
|
113 |
+
)
|
114 |
+
)
|
115 |
+
|
116 |
+
def set_target_size(self, size):
|
117 |
+
"""
|
118 |
+
Set final interpolation's target size
|
119 |
+
|
120 |
+
Args:
|
121 |
+
size (int, list, tuple): target size (h, w). If int, target will be (i, i)
|
122 |
+
"""
|
123 |
+
if isinstance(size, (list, tuple)):
|
124 |
+
self._target_size = size[:2]
|
125 |
+
else:
|
126 |
+
self._target_size = (size, size)
|
127 |
+
|
128 |
+
def forward(self, z):
|
129 |
+
if isinstance(z, (list, tuple)):
|
130 |
+
z = z[0]
|
131 |
+
z4_enc = self.enc4_1(z)
|
132 |
+
z4_enc = self.enc4_2(z4_enc)
|
133 |
+
z4_enc = self.enc4_3(z4_enc)
|
134 |
+
|
135 |
+
z_depth = None
|
136 |
+
if self.do_feat_fusion:
|
137 |
+
z_depth = self.dec4(z4_enc)
|
138 |
+
|
139 |
+
if self.upsample is not None:
|
140 |
+
z4_enc = self.upsample(z4_enc)
|
141 |
+
|
142 |
+
depth = torch.mean(z4_enc, dim=1, keepdim=True) # DADA paper decoder
|
143 |
+
if depth.shape[-1] != self._target_size:
|
144 |
+
depth = F.interpolate(
|
145 |
+
depth,
|
146 |
+
size=(384, 384), # size used in MiDaS inference
|
147 |
+
mode="bicubic", # what MiDaS uses
|
148 |
+
align_corners=False,
|
149 |
+
)
|
150 |
+
|
151 |
+
depth = F.interpolate(
|
152 |
+
depth, (self._target_size, self._target_size), mode="nearest"
|
153 |
+
) # what we used in the transforms to resize input
|
154 |
+
|
155 |
+
return depth, z_depth
|
156 |
+
|
157 |
+
def __str__(self):
|
158 |
+
return "DADA Depth Decoder"
|
159 |
+
|
160 |
+
|
161 |
+
class BaseDepthDecoder(BaseDecoder):
|
162 |
+
def __init__(self, opts):
|
163 |
+
low_level_feats_dim = -1
|
164 |
+
use_v3 = opts.gen.encoder.architecture == "deeplabv3"
|
165 |
+
use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
|
166 |
+
use_low = opts.gen.d.use_low_level_feats
|
167 |
+
|
168 |
+
if use_v3 and use_mobile_net:
|
169 |
+
input_dim = 320
|
170 |
+
if use_low:
|
171 |
+
low_level_feats_dim = 24
|
172 |
+
elif use_v3:
|
173 |
+
input_dim = 2048
|
174 |
+
if use_low:
|
175 |
+
low_level_feats_dim = 256
|
176 |
+
else:
|
177 |
+
input_dim = 2048
|
178 |
+
|
179 |
+
n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
|
180 |
+
output_dim = (
|
181 |
+
1
|
182 |
+
if not opts.gen.d.classify.enable
|
183 |
+
else opts.gen.d.classify.linspace.buckets
|
184 |
+
)
|
185 |
+
|
186 |
+
self._target_size = find_target_size(opts, "d")
|
187 |
+
print(
|
188 |
+
" - {}: setting target size to {}".format(
|
189 |
+
self.__class__.__name__, self._target_size
|
190 |
+
)
|
191 |
+
)
|
192 |
+
|
193 |
+
super().__init__(
|
194 |
+
n_upsample=n_upsample,
|
195 |
+
n_res=opts.gen.d.n_res,
|
196 |
+
input_dim=input_dim,
|
197 |
+
proj_dim=opts.gen.d.proj_dim,
|
198 |
+
output_dim=output_dim,
|
199 |
+
norm=opts.gen.d.norm,
|
200 |
+
activ=opts.gen.d.activ,
|
201 |
+
pad_type=opts.gen.d.pad_type,
|
202 |
+
output_activ="none",
|
203 |
+
low_level_feats_dim=low_level_feats_dim,
|
204 |
+
)
|
205 |
+
|
206 |
+
def set_target_size(self, size):
|
207 |
+
"""
|
208 |
+
Set final interpolation's target size
|
209 |
+
|
210 |
+
Args:
|
211 |
+
size (int, list, tuple): target size (h, w). If int, target will be (i, i)
|
212 |
+
"""
|
213 |
+
if isinstance(size, (list, tuple)):
|
214 |
+
self._target_size = size[:2]
|
215 |
+
else:
|
216 |
+
self._target_size = (size, size)
|
217 |
+
|
218 |
+
def forward(self, z, cond=None):
|
219 |
+
if self._target_size is None:
|
220 |
+
error = "self._target_size should be set with self.set_target_size()"
|
221 |
+
error += "to interpolate depth to the target depth map's size"
|
222 |
+
raise ValueError(error)
|
223 |
+
|
224 |
+
d = super().forward(z)
|
225 |
+
|
226 |
+
preds = F.interpolate(
|
227 |
+
d, size=self._target_size, mode="bilinear", align_corners=True
|
228 |
+
)
|
229 |
+
|
230 |
+
return preds, None
|
climategan/discriminator.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Discriminator architecture for ClimateGAN's GAN components (a and t)
|
2 |
+
"""
|
3 |
+
import functools
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from climategan.blocks import SpectralNorm
|
9 |
+
from climategan.tutils import init_weights
|
10 |
+
|
11 |
+
# from torch.optim import lr_scheduler
|
12 |
+
|
13 |
+
# mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py
|
14 |
+
|
15 |
+
|
16 |
+
def create_discriminator(opts, device, no_init=False, verbose=0):
|
17 |
+
disc = OmniDiscriminator(opts)
|
18 |
+
if no_init:
|
19 |
+
return disc
|
20 |
+
|
21 |
+
for task, model in disc.items():
|
22 |
+
if isinstance(model, nn.ModuleDict):
|
23 |
+
for domain, domain_model in model.items():
|
24 |
+
init_weights(
|
25 |
+
domain_model,
|
26 |
+
init_type=opts.dis[task].init_type,
|
27 |
+
init_gain=opts.dis[task].init_gain,
|
28 |
+
verbose=verbose,
|
29 |
+
caller=f"create_discriminator {task} {domain}",
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
init_weights(
|
33 |
+
model,
|
34 |
+
init_type=opts.dis[task].init_type,
|
35 |
+
init_gain=opts.dis[task].init_gain,
|
36 |
+
verbose=verbose,
|
37 |
+
caller=f"create_discriminator {task}",
|
38 |
+
)
|
39 |
+
return disc.to(device)
|
40 |
+
|
41 |
+
|
42 |
+
def define_D(
|
43 |
+
input_nc,
|
44 |
+
ndf,
|
45 |
+
n_layers=3,
|
46 |
+
norm="batch",
|
47 |
+
use_sigmoid=False,
|
48 |
+
get_intermediate_features=False,
|
49 |
+
num_D=1,
|
50 |
+
):
|
51 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
52 |
+
net = MultiscaleDiscriminator(
|
53 |
+
input_nc,
|
54 |
+
ndf,
|
55 |
+
n_layers=n_layers,
|
56 |
+
norm_layer=norm_layer,
|
57 |
+
use_sigmoid=use_sigmoid,
|
58 |
+
get_intermediate_features=get_intermediate_features,
|
59 |
+
num_D=num_D,
|
60 |
+
)
|
61 |
+
return net
|
62 |
+
|
63 |
+
|
64 |
+
def get_norm_layer(norm_type="instance"):
|
65 |
+
if not norm_type:
|
66 |
+
print("norm_type is {}, defaulting to instance")
|
67 |
+
norm_type = "instance"
|
68 |
+
if norm_type == "batch":
|
69 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
70 |
+
elif norm_type == "instance":
|
71 |
+
norm_layer = functools.partial(
|
72 |
+
nn.InstanceNorm2d, affine=False, track_running_stats=False
|
73 |
+
)
|
74 |
+
elif norm_type == "none":
|
75 |
+
norm_layer = None
|
76 |
+
else:
|
77 |
+
raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
|
78 |
+
return norm_layer
|
79 |
+
|
80 |
+
|
81 |
+
# Defines the PatchGAN discriminator with the specified arguments.
|
82 |
+
class NLayerDiscriminator(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
input_nc=3,
|
86 |
+
ndf=64,
|
87 |
+
n_layers=3,
|
88 |
+
norm_layer=nn.BatchNorm2d,
|
89 |
+
use_sigmoid=False,
|
90 |
+
get_intermediate_features=True,
|
91 |
+
):
|
92 |
+
super(NLayerDiscriminator, self).__init__()
|
93 |
+
if type(norm_layer) == functools.partial:
|
94 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
95 |
+
else:
|
96 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
97 |
+
|
98 |
+
self.get_intermediate_features = get_intermediate_features
|
99 |
+
|
100 |
+
kw = 4
|
101 |
+
padw = 1
|
102 |
+
sequence = [
|
103 |
+
[
|
104 |
+
# Use spectral normalization
|
105 |
+
SpectralNorm(
|
106 |
+
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
|
107 |
+
),
|
108 |
+
nn.LeakyReLU(0.2, True),
|
109 |
+
]
|
110 |
+
]
|
111 |
+
|
112 |
+
nf_mult = 1
|
113 |
+
nf_mult_prev = 1
|
114 |
+
for n in range(1, n_layers):
|
115 |
+
nf_mult_prev = nf_mult
|
116 |
+
nf_mult = min(2 ** n, 8)
|
117 |
+
sequence += [
|
118 |
+
[
|
119 |
+
# Use spectral normalization
|
120 |
+
SpectralNorm( # TODO replace with Conv2dBlock
|
121 |
+
nn.Conv2d(
|
122 |
+
ndf * nf_mult_prev,
|
123 |
+
ndf * nf_mult,
|
124 |
+
kernel_size=kw,
|
125 |
+
stride=2,
|
126 |
+
padding=padw,
|
127 |
+
bias=use_bias,
|
128 |
+
)
|
129 |
+
),
|
130 |
+
norm_layer(ndf * nf_mult),
|
131 |
+
nn.LeakyReLU(0.2, True),
|
132 |
+
]
|
133 |
+
]
|
134 |
+
|
135 |
+
nf_mult_prev = nf_mult
|
136 |
+
nf_mult = min(2 ** n_layers, 8)
|
137 |
+
sequence += [
|
138 |
+
[
|
139 |
+
# Use spectral normalization
|
140 |
+
SpectralNorm(
|
141 |
+
nn.Conv2d(
|
142 |
+
ndf * nf_mult_prev,
|
143 |
+
ndf * nf_mult,
|
144 |
+
kernel_size=kw,
|
145 |
+
stride=1,
|
146 |
+
padding=padw,
|
147 |
+
bias=use_bias,
|
148 |
+
)
|
149 |
+
),
|
150 |
+
norm_layer(ndf * nf_mult),
|
151 |
+
nn.LeakyReLU(0.2, True),
|
152 |
+
]
|
153 |
+
]
|
154 |
+
|
155 |
+
# Use spectral normalization
|
156 |
+
sequence += [
|
157 |
+
[
|
158 |
+
SpectralNorm(
|
159 |
+
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
160 |
+
)
|
161 |
+
]
|
162 |
+
]
|
163 |
+
|
164 |
+
if use_sigmoid:
|
165 |
+
sequence += [[nn.Sigmoid()]]
|
166 |
+
|
167 |
+
# We divide the layers into groups to extract intermediate layer outputs
|
168 |
+
for n in range(len(sequence)):
|
169 |
+
self.add_module("model" + str(n), nn.Sequential(*sequence[n]))
|
170 |
+
# self.model = nn.Sequential(*sequence)
|
171 |
+
|
172 |
+
def forward(self, input):
|
173 |
+
results = [input]
|
174 |
+
for submodel in self.children():
|
175 |
+
intermediate_output = submodel(results[-1])
|
176 |
+
results.append(intermediate_output)
|
177 |
+
|
178 |
+
get_intermediate_features = self.get_intermediate_features
|
179 |
+
if get_intermediate_features:
|
180 |
+
return results[1:]
|
181 |
+
else:
|
182 |
+
return results[-1]
|
183 |
+
|
184 |
+
|
185 |
+
# def forward(self, input):
|
186 |
+
# return self.model(input)
|
187 |
+
|
188 |
+
|
189 |
+
# Source: https://github.com/NVIDIA/pix2pixHD
|
190 |
+
class MultiscaleDiscriminator(nn.Module):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
input_nc=3,
|
194 |
+
ndf=64,
|
195 |
+
n_layers=3,
|
196 |
+
norm_layer=nn.BatchNorm2d,
|
197 |
+
use_sigmoid=False,
|
198 |
+
get_intermediate_features=True,
|
199 |
+
num_D=3,
|
200 |
+
):
|
201 |
+
super(MultiscaleDiscriminator, self).__init__()
|
202 |
+
# self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
|
203 |
+
# use_sigmoid=False, num_D=3, getIntermFeat=False
|
204 |
+
|
205 |
+
self.n_layers = n_layers
|
206 |
+
self.ndf = ndf
|
207 |
+
self.norm_layer = norm_layer
|
208 |
+
self.use_sigmoid = use_sigmoid
|
209 |
+
self.get_intermediate_features = get_intermediate_features
|
210 |
+
self.num_D = num_D
|
211 |
+
|
212 |
+
for i in range(self.num_D):
|
213 |
+
netD = NLayerDiscriminator(
|
214 |
+
input_nc=input_nc,
|
215 |
+
ndf=self.ndf,
|
216 |
+
n_layers=self.n_layers,
|
217 |
+
norm_layer=self.norm_layer,
|
218 |
+
use_sigmoid=self.use_sigmoid,
|
219 |
+
get_intermediate_features=self.get_intermediate_features,
|
220 |
+
)
|
221 |
+
self.add_module("discriminator_%d" % i, netD)
|
222 |
+
|
223 |
+
self.downsample = nn.AvgPool2d(
|
224 |
+
3, stride=2, padding=[1, 1], count_include_pad=False
|
225 |
+
)
|
226 |
+
|
227 |
+
def forward(self, input):
|
228 |
+
result = []
|
229 |
+
get_intermediate_features = self.get_intermediate_features
|
230 |
+
for name, D in self.named_children():
|
231 |
+
if "discriminator" not in name:
|
232 |
+
continue
|
233 |
+
out = D(input)
|
234 |
+
if not get_intermediate_features:
|
235 |
+
out = [out]
|
236 |
+
result.append(out)
|
237 |
+
input = self.downsample(input)
|
238 |
+
|
239 |
+
return result
|
240 |
+
|
241 |
+
|
242 |
+
class OmniDiscriminator(nn.ModuleDict):
|
243 |
+
def __init__(self, opts):
|
244 |
+
super().__init__()
|
245 |
+
if "p" in opts.tasks:
|
246 |
+
if opts.dis.p.use_local_discriminator:
|
247 |
+
|
248 |
+
self["p"] = nn.ModuleDict(
|
249 |
+
{
|
250 |
+
"global": define_D(
|
251 |
+
input_nc=3,
|
252 |
+
ndf=opts.dis.p.ndf,
|
253 |
+
n_layers=opts.dis.p.n_layers,
|
254 |
+
norm=opts.dis.p.norm,
|
255 |
+
use_sigmoid=opts.dis.p.use_sigmoid,
|
256 |
+
get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501
|
257 |
+
num_D=opts.dis.p.num_D,
|
258 |
+
),
|
259 |
+
"local": define_D(
|
260 |
+
input_nc=3,
|
261 |
+
ndf=opts.dis.p.ndf,
|
262 |
+
n_layers=opts.dis.p.n_layers,
|
263 |
+
norm=opts.dis.p.norm,
|
264 |
+
use_sigmoid=opts.dis.p.use_sigmoid,
|
265 |
+
get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501
|
266 |
+
num_D=opts.dis.p.num_D,
|
267 |
+
),
|
268 |
+
}
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
self["p"] = define_D(
|
272 |
+
input_nc=4, # image + mask
|
273 |
+
ndf=opts.dis.p.ndf,
|
274 |
+
n_layers=opts.dis.p.n_layers,
|
275 |
+
norm=opts.dis.p.norm,
|
276 |
+
use_sigmoid=opts.dis.p.use_sigmoid,
|
277 |
+
get_intermediate_features=opts.dis.p.get_intermediate_features,
|
278 |
+
num_D=opts.dis.p.num_D,
|
279 |
+
)
|
280 |
+
if "m" in opts.tasks:
|
281 |
+
if opts.gen.m.use_advent:
|
282 |
+
if opts.dis.m.architecture == "base":
|
283 |
+
if opts.dis.m.gan_type == "WGAN_norm":
|
284 |
+
self["m"] = nn.ModuleDict(
|
285 |
+
{
|
286 |
+
"Advent": get_fc_discriminator(
|
287 |
+
num_classes=2, use_norm=True
|
288 |
+
)
|
289 |
+
}
|
290 |
+
)
|
291 |
+
else:
|
292 |
+
self["m"] = nn.ModuleDict(
|
293 |
+
{
|
294 |
+
"Advent": get_fc_discriminator(
|
295 |
+
num_classes=2, use_norm=False
|
296 |
+
)
|
297 |
+
}
|
298 |
+
)
|
299 |
+
elif opts.dis.m.architecture == "OmniDiscriminator":
|
300 |
+
self["m"] = nn.ModuleDict(
|
301 |
+
{
|
302 |
+
"Advent": define_D(
|
303 |
+
input_nc=2,
|
304 |
+
ndf=opts.dis.m.ndf,
|
305 |
+
n_layers=opts.dis.m.n_layers,
|
306 |
+
norm=opts.dis.m.norm,
|
307 |
+
use_sigmoid=opts.dis.m.use_sigmoid,
|
308 |
+
get_intermediate_features=opts.dis.m.get_intermediate_features, # noqa: E501
|
309 |
+
num_D=opts.dis.m.num_D,
|
310 |
+
)
|
311 |
+
}
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
raise Exception("This Discriminator is currently not supported!")
|
315 |
+
if "s" in opts.tasks:
|
316 |
+
if opts.gen.s.use_advent:
|
317 |
+
if opts.dis.s.gan_type == "WGAN_norm":
|
318 |
+
self["s"] = nn.ModuleDict(
|
319 |
+
{"Advent": get_fc_discriminator(num_classes=11, use_norm=True)}
|
320 |
+
)
|
321 |
+
else:
|
322 |
+
self["s"] = nn.ModuleDict(
|
323 |
+
{"Advent": get_fc_discriminator(num_classes=11, use_norm=False)}
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
|
328 |
+
if use_norm:
|
329 |
+
return torch.nn.Sequential(
|
330 |
+
SpectralNorm(
|
331 |
+
torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
|
332 |
+
),
|
333 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
334 |
+
SpectralNorm(
|
335 |
+
torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
|
336 |
+
),
|
337 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
338 |
+
SpectralNorm(
|
339 |
+
torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
|
340 |
+
),
|
341 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
342 |
+
SpectralNorm(
|
343 |
+
torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
|
344 |
+
),
|
345 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
346 |
+
SpectralNorm(
|
347 |
+
torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
|
348 |
+
),
|
349 |
+
)
|
350 |
+
else:
|
351 |
+
return torch.nn.Sequential(
|
352 |
+
torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
|
353 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
354 |
+
torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
|
355 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
356 |
+
torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
|
357 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
358 |
+
torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
|
359 |
+
torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
360 |
+
torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
|
361 |
+
)
|
climategan/eval_metrics.py
ADDED
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from skimage import filters
|
5 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import seaborn as sns
|
8 |
+
from copy import deepcopy
|
9 |
+
|
10 |
+
# ------------------------------------------------------------------------------
|
11 |
+
# ----- Evaluation metrics for a pair of binary mask images (pred, target) -----
|
12 |
+
# ------------------------------------------------------------------------------
|
13 |
+
|
14 |
+
|
15 |
+
def get_accuracy(arr1, arr2):
|
16 |
+
"""pixel accuracy
|
17 |
+
|
18 |
+
Args:
|
19 |
+
arr1 (np.array)
|
20 |
+
arr2 (np.array)
|
21 |
+
"""
|
22 |
+
return (arr1 == arr2).sum() / arr1.size
|
23 |
+
|
24 |
+
|
25 |
+
def trimap(pred_im, gt_im, thickness=8):
|
26 |
+
"""Compute accuracy in a region of thickness around the contours
|
27 |
+
for binary images (0-1 values)
|
28 |
+
Args:
|
29 |
+
pred_im (Image): Prediction
|
30 |
+
gt_im (Image): Target
|
31 |
+
thickness (int, optional): [description]. Defaults to 8.
|
32 |
+
"""
|
33 |
+
W, H = gt_im.size
|
34 |
+
contours, hierarchy = cv2.findContours(
|
35 |
+
np.array(gt_im), mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
|
36 |
+
)
|
37 |
+
mask_contour = np.zeros((H, W), dtype=np.int32)
|
38 |
+
cv2.drawContours(
|
39 |
+
mask_contour, contours, -1, (1), thickness=thickness, hierarchy=hierarchy
|
40 |
+
)
|
41 |
+
gt_contour = np.array(gt_im)[np.where(mask_contour > 0)]
|
42 |
+
pred_contour = np.array(pred_im)[np.where(mask_contour > 0)]
|
43 |
+
return get_accuracy(pred_contour, gt_contour)
|
44 |
+
|
45 |
+
|
46 |
+
def iou(pred_im, gt_im):
|
47 |
+
"""
|
48 |
+
IoU for binary masks (0-1 values)
|
49 |
+
|
50 |
+
Args:
|
51 |
+
pred_im ([type]): [description]
|
52 |
+
gt_im ([type]): [description]
|
53 |
+
"""
|
54 |
+
pred = np.array(pred_im)
|
55 |
+
gt = np.array(gt_im)
|
56 |
+
intersection = (pred * gt).sum()
|
57 |
+
union = (pred + gt).sum() - intersection
|
58 |
+
return intersection / union
|
59 |
+
|
60 |
+
|
61 |
+
def f1_score(pred_im, gt_im):
|
62 |
+
pred = np.array(pred_im)
|
63 |
+
gt = np.array(gt_im)
|
64 |
+
intersection = (pred * gt).sum()
|
65 |
+
return 2 * intersection / (pred + gt).sum()
|
66 |
+
|
67 |
+
|
68 |
+
def accuracy(pred_im, gt_im):
|
69 |
+
pred = np.array(pred_im)
|
70 |
+
gt = np.array(gt_im)
|
71 |
+
if len(gt_im.shape) == 4:
|
72 |
+
assert gt_im.shape[1] == 1
|
73 |
+
gt_im = gt_im[:, 0, :, :]
|
74 |
+
if len(pred.shape) > len(gt_im.shape):
|
75 |
+
pred = np.argmax(pred, axis=1)
|
76 |
+
return float((pred == gt).sum()) / gt.size
|
77 |
+
|
78 |
+
|
79 |
+
def mIOU(pred, label, average="macro"):
|
80 |
+
"""
|
81 |
+
Adapted from:
|
82 |
+
https://stackoverflow.com/questions/62461379/multiclass-semantic-segmentation-model-evaluation
|
83 |
+
|
84 |
+
Compute the mean IOU from pred and label tensors
|
85 |
+
pred is a tensor N x C x H x W with logits (softmax will be applied)
|
86 |
+
and label is a N x H x W tensor with int labels per pixel
|
87 |
+
|
88 |
+
this does the same as sklearn's jaccard_score function if you choose average="macro"
|
89 |
+
Args:
|
90 |
+
pred (torch.tensor): predicted logits
|
91 |
+
label (torch.tensor): labels
|
92 |
+
average: "macro" or "weighted"
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
float: mIOU, can be nan
|
96 |
+
"""
|
97 |
+
num_classes = pred.shape[-3]
|
98 |
+
|
99 |
+
pred = torch.argmax(pred, dim=1).squeeze(1)
|
100 |
+
present_iou_list = list()
|
101 |
+
pred = pred.view(-1)
|
102 |
+
label = label.view(-1)
|
103 |
+
# Note: Following for loop goes from 0 to (num_classes-1)
|
104 |
+
# and ignore_index is num_classes, thus ignore_index is
|
105 |
+
# not considered in computation of IoU.
|
106 |
+
interesting_classes = (
|
107 |
+
[*range(num_classes)] if num_classes > 2 else [int(label.max().item())]
|
108 |
+
)
|
109 |
+
weights = []
|
110 |
+
|
111 |
+
for sem_class in interesting_classes:
|
112 |
+
pred_inds = pred == sem_class
|
113 |
+
target_inds = label == sem_class
|
114 |
+
if (target_inds.long().sum().item() > 0) or (pred_inds.long().sum().item() > 0):
|
115 |
+
intersection_now = (pred_inds[target_inds]).long().sum().item()
|
116 |
+
union_now = (
|
117 |
+
pred_inds.long().sum().item()
|
118 |
+
+ target_inds.long().sum().item()
|
119 |
+
- intersection_now
|
120 |
+
)
|
121 |
+
weights.append(pred_inds.long().sum().item())
|
122 |
+
iou_now = float(intersection_now) / float(union_now)
|
123 |
+
present_iou_list.append(iou_now)
|
124 |
+
if not present_iou_list:
|
125 |
+
return float("nan")
|
126 |
+
elif average == "weighted":
|
127 |
+
weighted_avg = np.sum(np.multiply(weights, present_iou_list) / np.sum(weights))
|
128 |
+
return weighted_avg
|
129 |
+
else:
|
130 |
+
return np.mean(present_iou_list)
|
131 |
+
|
132 |
+
|
133 |
+
def masker_classification_metrics(
|
134 |
+
pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
|
135 |
+
):
|
136 |
+
"""
|
137 |
+
Classification metrics for the masker, and the corresponding maps. If the
|
138 |
+
predictions are soft, the errors are weighted accordingly. Metrics computed:
|
139 |
+
|
140 |
+
tpr : float
|
141 |
+
True positive rate
|
142 |
+
|
143 |
+
tpt : float
|
144 |
+
True positive total (divided by total population)
|
145 |
+
|
146 |
+
tnr : float
|
147 |
+
True negative rate
|
148 |
+
|
149 |
+
tnt : float
|
150 |
+
True negative total (divided by total population)
|
151 |
+
|
152 |
+
fpr : float
|
153 |
+
False positive rate: rate of predicted mask on cannot flood
|
154 |
+
|
155 |
+
fpt : float
|
156 |
+
False positive total (divided by total population)
|
157 |
+
|
158 |
+
fnr : float
|
159 |
+
False negative rate: rate of missed mask on must flood
|
160 |
+
|
161 |
+
fnt : float
|
162 |
+
False negative total (divided by total population)
|
163 |
+
|
164 |
+
mnr : float
|
165 |
+
"May" negative rate (labeled as "may", predicted as no-mask)
|
166 |
+
|
167 |
+
mpr : float
|
168 |
+
"May" positive rate (labeled as "may", predicted as mask)
|
169 |
+
|
170 |
+
accuracy : float
|
171 |
+
Accuracy
|
172 |
+
|
173 |
+
error : float
|
174 |
+
Error
|
175 |
+
|
176 |
+
precision : float
|
177 |
+
Precision, considering only cannot and must flood labels
|
178 |
+
|
179 |
+
f05 : float
|
180 |
+
F0.5 score, considering only cannot and must flood labels
|
181 |
+
|
182 |
+
accuracy_must_may : float
|
183 |
+
Accuracy considering only the must and may areas
|
184 |
+
|
185 |
+
Parameters
|
186 |
+
----------
|
187 |
+
pred : array-like
|
188 |
+
Mask prediction
|
189 |
+
|
190 |
+
label : array-like
|
191 |
+
Mask ground truth labels
|
192 |
+
|
193 |
+
labels_dict : dict
|
194 |
+
A dictionary with the identifier of each class (cannot, must, may)
|
195 |
+
|
196 |
+
Returns
|
197 |
+
-------
|
198 |
+
metrics_dict : dict
|
199 |
+
A dictionary with metric name and value pairs
|
200 |
+
|
201 |
+
maps_dict : dict
|
202 |
+
A dictionary containing the metric maps
|
203 |
+
"""
|
204 |
+
tp_map = pred * np.asarray(label == labels_dict["must"], dtype=int)
|
205 |
+
tpr = np.sum(tp_map) / np.sum(label == labels_dict["must"])
|
206 |
+
tpt = np.sum(tp_map) / np.prod(label.shape)
|
207 |
+
tn_map = (1.0 - pred) * np.asarray(label == labels_dict["cannot"], dtype=int)
|
208 |
+
tnr = np.sum(tn_map) / np.sum(label == labels_dict["cannot"])
|
209 |
+
tnt = np.sum(tn_map) / np.prod(label.shape)
|
210 |
+
fp_map = pred * np.asarray(label == labels_dict["cannot"], dtype=int)
|
211 |
+
fpr = np.sum(fp_map) / np.sum(label == labels_dict["cannot"])
|
212 |
+
fpt = np.sum(fp_map) / np.prod(label.shape)
|
213 |
+
fn_map = (1.0 - pred) * np.asarray(label == labels_dict["must"], dtype=int)
|
214 |
+
fnr = np.sum(fn_map) / np.sum(label == labels_dict["must"])
|
215 |
+
fnt = np.sum(fn_map) / np.prod(label.shape)
|
216 |
+
may_neg_map = (1.0 - pred) * np.asarray(label == labels_dict["may"], dtype=int)
|
217 |
+
may_pos_map = pred * np.asarray(label == labels_dict["may"], dtype=int)
|
218 |
+
mnr = np.sum(may_neg_map) / np.sum(label == labels_dict["may"])
|
219 |
+
mpr = np.sum(may_pos_map) / np.sum(label == labels_dict["may"])
|
220 |
+
accuracy = tpt + tnt
|
221 |
+
error = fpt + fnt
|
222 |
+
|
223 |
+
# Assertions
|
224 |
+
assert np.isclose(tpr, 1.0 - fnr), "TPR: {:.4f}, FNR: {:.4f}".format(tpr, fnr)
|
225 |
+
assert np.isclose(tnr, 1.0 - fpr), "TNR: {:.4f}, FPR: {:.4f}".format(tnr, fpr)
|
226 |
+
assert np.isclose(mpr, 1.0 - mnr), "MPR: {:.4f}, MNR: {:.4f}".format(mpr, mnr)
|
227 |
+
|
228 |
+
precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map) + 1e-9)
|
229 |
+
beta = 0.5
|
230 |
+
f05 = ((1 + beta ** 2) * precision * tpr) / (beta ** 2 * precision + tpr + 1e-9)
|
231 |
+
accuracy_must_may = (np.sum(tp_map) + np.sum(may_neg_map)) / (
|
232 |
+
np.sum(label == labels_dict["must"]) + np.sum(label == labels_dict["may"])
|
233 |
+
)
|
234 |
+
|
235 |
+
metrics_dict = {
|
236 |
+
"tpr": tpr,
|
237 |
+
"tpt": tpt,
|
238 |
+
"tnr": tnr,
|
239 |
+
"tnt": tnt,
|
240 |
+
"fpr": fpr,
|
241 |
+
"fpt": fpt,
|
242 |
+
"fnr": fnr,
|
243 |
+
"fnt": fnt,
|
244 |
+
"mpr": mpr,
|
245 |
+
"mnr": mnr,
|
246 |
+
"accuracy": accuracy,
|
247 |
+
"error": error,
|
248 |
+
"precision": precision,
|
249 |
+
"f05": f05,
|
250 |
+
"accuracy_must_may": accuracy_must_may,
|
251 |
+
}
|
252 |
+
maps_dict = {
|
253 |
+
"tp": tp_map,
|
254 |
+
"tn": tn_map,
|
255 |
+
"fp": fp_map,
|
256 |
+
"fn": fn_map,
|
257 |
+
"may_pos": may_pos_map,
|
258 |
+
"may_neg": may_neg_map,
|
259 |
+
}
|
260 |
+
|
261 |
+
return metrics_dict, maps_dict
|
262 |
+
|
263 |
+
|
264 |
+
def pred_cannot(pred, label, label_cannot=0):
|
265 |
+
"""
|
266 |
+
Metric for the masker: Computes false positive rate and its map. If the
|
267 |
+
predictions are soft, the errors are weighted accordingly.
|
268 |
+
|
269 |
+
Parameters
|
270 |
+
----------
|
271 |
+
pred : array-like
|
272 |
+
Mask prediction
|
273 |
+
|
274 |
+
label : array-like
|
275 |
+
Mask ground truth labels
|
276 |
+
|
277 |
+
label_cannot : int
|
278 |
+
The label index of "cannot flood"
|
279 |
+
|
280 |
+
Returns
|
281 |
+
-------
|
282 |
+
fp_map : array-like
|
283 |
+
The map of false positives: predicted mask on cannot flood
|
284 |
+
|
285 |
+
fpr : float
|
286 |
+
False positive rate: rate of predicted mask on cannot flood
|
287 |
+
"""
|
288 |
+
fp_map = pred * np.asarray(label == label_cannot, dtype=int)
|
289 |
+
fpr = np.sum(fp_map) / np.sum(label == label_cannot)
|
290 |
+
return fp_map, fpr
|
291 |
+
|
292 |
+
|
293 |
+
def missed_must(pred, label, label_must=1):
|
294 |
+
"""
|
295 |
+
Metric for the masker: Computes false negative rate and its map. If the
|
296 |
+
predictions are soft, the errors are weighted accordingly.
|
297 |
+
|
298 |
+
Parameters
|
299 |
+
----------
|
300 |
+
pred : array-like
|
301 |
+
Mask prediction
|
302 |
+
|
303 |
+
label : array-like
|
304 |
+
Mask ground truth labels
|
305 |
+
|
306 |
+
label_must : int
|
307 |
+
The label index of "must flood"
|
308 |
+
|
309 |
+
Returns
|
310 |
+
-------
|
311 |
+
fn_map : array-like
|
312 |
+
The map of false negatives: missed mask on must flood
|
313 |
+
|
314 |
+
fnr : float
|
315 |
+
False negative rate: rate of missed mask on must flood
|
316 |
+
"""
|
317 |
+
fn_map = (1.0 - pred) * np.asarray(label == label_must, dtype=int)
|
318 |
+
fnr = np.sum(fn_map) / np.sum(label == label_must)
|
319 |
+
return fn_map, fnr
|
320 |
+
|
321 |
+
|
322 |
+
def may_flood(pred, label, label_may=2):
|
323 |
+
"""
|
324 |
+
Metric for the masker: Computes "may" negative and "may" positive rates and their
|
325 |
+
map. If the predictions are soft, the "errors" are weighted accordingly.
|
326 |
+
|
327 |
+
Parameters
|
328 |
+
----------
|
329 |
+
pred : array-like
|
330 |
+
Mask prediction
|
331 |
+
|
332 |
+
label : array-like
|
333 |
+
Mask ground truth labels
|
334 |
+
|
335 |
+
label_may : int
|
336 |
+
The label index of "may flood"
|
337 |
+
|
338 |
+
Returns
|
339 |
+
-------
|
340 |
+
may_neg_map : array-like
|
341 |
+
The map of "may" negatives
|
342 |
+
|
343 |
+
may_pos_map : array-like
|
344 |
+
The map of "may" positives
|
345 |
+
|
346 |
+
mnr : float
|
347 |
+
"May" negative rate
|
348 |
+
|
349 |
+
mpr : float
|
350 |
+
"May" positive rate
|
351 |
+
"""
|
352 |
+
may_neg_map = (1.0 - pred) * np.asarray(label == label_may, dtype=int)
|
353 |
+
may_pos_map = pred * np.asarray(label == label_may, dtype=int)
|
354 |
+
mnr = np.sum(may_neg_map) / np.sum(label == label_may)
|
355 |
+
mpr = np.sum(may_pos_map) / np.sum(label == label_may)
|
356 |
+
return may_neg_map, may_pos_map, mnr, mpr
|
357 |
+
|
358 |
+
|
359 |
+
def masker_metrics(pred, label, label_cannot=0, label_must=1):
|
360 |
+
"""
|
361 |
+
Computes a set of metrics for the masker
|
362 |
+
|
363 |
+
Parameters
|
364 |
+
----------
|
365 |
+
pred : array-like
|
366 |
+
Mask prediction
|
367 |
+
|
368 |
+
label : array-like
|
369 |
+
Mask ground truth labels
|
370 |
+
|
371 |
+
label_must : int
|
372 |
+
The label index of "must flood"
|
373 |
+
|
374 |
+
label_cannot : int
|
375 |
+
The label index of "cannot flood"
|
376 |
+
|
377 |
+
Returns
|
378 |
+
-------
|
379 |
+
tpr : float
|
380 |
+
True positive rate
|
381 |
+
|
382 |
+
tnr : float
|
383 |
+
True negative rate
|
384 |
+
|
385 |
+
precision : float
|
386 |
+
Precision, considering only cannot and must flood labels
|
387 |
+
|
388 |
+
f1 : float
|
389 |
+
F1 score, considering only cannot and must flood labels
|
390 |
+
"""
|
391 |
+
tp_map = pred * np.asarray(label == label_must, dtype=int)
|
392 |
+
tpr = np.sum(tp_map) / np.sum(label == label_must)
|
393 |
+
tn_map = (1.0 - pred) * np.asarray(label == label_cannot, dtype=int)
|
394 |
+
tnr = np.sum(tn_map) / np.sum(label == label_cannot)
|
395 |
+
fp_map = pred * np.asarray(label == label_cannot, dtype=int)
|
396 |
+
fn_map = (1.0 - pred) * np.asarray(label == label_must, dtype=int) # noqa: F841
|
397 |
+
precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map))
|
398 |
+
f1 = 2 * (precision * tpr) / (precision + tpr)
|
399 |
+
return tpr, tnr, precision, f1
|
400 |
+
|
401 |
+
|
402 |
+
def get_confusion_matrix(tpr, tnr, fpr, fnr, mpr, mnr):
|
403 |
+
"""
|
404 |
+
Constructs the confusion matrix of a masker prediction over a set of samples
|
405 |
+
|
406 |
+
Parameters
|
407 |
+
----------
|
408 |
+
tpr : vector-like
|
409 |
+
True positive rate
|
410 |
+
|
411 |
+
tnr : vector-like
|
412 |
+
True negative rate
|
413 |
+
|
414 |
+
fpr : vector-like
|
415 |
+
False positive rate
|
416 |
+
|
417 |
+
fnr : vector-like
|
418 |
+
False negative rate
|
419 |
+
|
420 |
+
mpr : vector-like
|
421 |
+
"May" positive rate
|
422 |
+
|
423 |
+
mnr : vector-like
|
424 |
+
"May" negative rate
|
425 |
+
|
426 |
+
Returns
|
427 |
+
-------
|
428 |
+
confusion_matrix : 3x3 array
|
429 |
+
Confusion matrix: [i, j] = [pred, true]
|
430 |
+
| tnr fnr mnr |
|
431 |
+
| fpr tpr mpr |
|
432 |
+
| 0. 0, 0, |
|
433 |
+
|
434 |
+
confusion_matrix_std : 3x3 array
|
435 |
+
Standard deviation of the confusion matrix
|
436 |
+
"""
|
437 |
+
# Compute mean and standard deviations over all samples
|
438 |
+
tpr_m = np.mean(tpr)
|
439 |
+
tpr_s = np.std(tpr)
|
440 |
+
tnr_m = np.mean(tnr)
|
441 |
+
tnr_s = np.std(tnr)
|
442 |
+
fpr_m = np.mean(fpr)
|
443 |
+
fpr_s = np.std(fpr)
|
444 |
+
fnr_m = np.mean(fnr)
|
445 |
+
fnr_s = np.std(fnr)
|
446 |
+
mpr_m = np.mean(mpr)
|
447 |
+
mpr_s = np.std(mpr)
|
448 |
+
mnr_m = np.mean(mnr)
|
449 |
+
mnr_s = np.std(mnr)
|
450 |
+
|
451 |
+
# Assertions
|
452 |
+
assert np.isclose(tpr_m, 1.0 - fnr_m), "TPR: {:.4f}, FNR: {:.4f}".format(
|
453 |
+
tpr_m, fnr_m
|
454 |
+
)
|
455 |
+
assert np.isclose(tnr_m, 1.0 - fpr_m), "TNR: {:.4f}, FPR: {:.4f}".format(
|
456 |
+
tnr_m, fpr_m
|
457 |
+
)
|
458 |
+
assert np.isclose(mpr_m, 1.0 - mnr_m), "MPR: {:.4f}, MNR: {:.4f}".format(
|
459 |
+
mpr_m, mnr_m
|
460 |
+
)
|
461 |
+
|
462 |
+
# Fill confusion matrix
|
463 |
+
confusion_matrix = np.zeros((3, 3))
|
464 |
+
confusion_matrix[0, 0] = tnr_m
|
465 |
+
confusion_matrix[0, 1] = fnr_m
|
466 |
+
confusion_matrix[0, 2] = mnr_m
|
467 |
+
confusion_matrix[1, 0] = fpr_m
|
468 |
+
confusion_matrix[1, 1] = tpr_m
|
469 |
+
confusion_matrix[1, 2] = mpr_m
|
470 |
+
confusion_matrix[2, 2] = 0.0
|
471 |
+
|
472 |
+
# Standard deviation
|
473 |
+
confusion_matrix_std = np.zeros((3, 3))
|
474 |
+
confusion_matrix_std[0, 0] = tnr_s
|
475 |
+
confusion_matrix_std[0, 1] = fnr_s
|
476 |
+
confusion_matrix_std[0, 2] = mnr_s
|
477 |
+
confusion_matrix_std[1, 0] = fpr_s
|
478 |
+
confusion_matrix_std[1, 1] = tpr_s
|
479 |
+
confusion_matrix_std[1, 2] = mpr_s
|
480 |
+
confusion_matrix_std[2, 2] = 0.0
|
481 |
+
return confusion_matrix, confusion_matrix_std
|
482 |
+
|
483 |
+
|
484 |
+
def edges_coherence_std_min(pred, label, label_must=1, bin_th=0.5):
|
485 |
+
"""
|
486 |
+
The standard deviation of the minimum distance between the edge of the prediction
|
487 |
+
and the edge of the "must flood" label.
|
488 |
+
|
489 |
+
Parameters
|
490 |
+
----------
|
491 |
+
pred : array-like
|
492 |
+
Mask prediction
|
493 |
+
|
494 |
+
label : array-like
|
495 |
+
Mask ground truth labels
|
496 |
+
|
497 |
+
label_must : int
|
498 |
+
The label index of "must flood"
|
499 |
+
|
500 |
+
bin_th : float
|
501 |
+
The threshold for the binarization of the prediction
|
502 |
+
|
503 |
+
Returns
|
504 |
+
-------
|
505 |
+
metric : float
|
506 |
+
The value of the metric
|
507 |
+
|
508 |
+
pred_edge : array-like
|
509 |
+
The edges images of the prediction, for visualization
|
510 |
+
|
511 |
+
label_edge : array-like
|
512 |
+
The edges images of the "must flood" label, for visualization
|
513 |
+
"""
|
514 |
+
# Keep must flood label only
|
515 |
+
label = deepcopy(label)
|
516 |
+
label[label != label_must] = -1
|
517 |
+
label[label == label_must] = 1
|
518 |
+
label[label != label_must] = 0
|
519 |
+
label = np.asarray(label, dtype=float)
|
520 |
+
|
521 |
+
# Binarize prediction
|
522 |
+
pred = np.asarray(pred > bin_th, dtype=float)
|
523 |
+
|
524 |
+
# Compute edges
|
525 |
+
pred = filters.sobel(pred)
|
526 |
+
label = filters.sobel(label)
|
527 |
+
|
528 |
+
# Location of edges
|
529 |
+
pred_coord = np.argwhere(pred > 0)
|
530 |
+
label_coord = np.argwhere(label > 0)
|
531 |
+
|
532 |
+
# Handle blank predictions
|
533 |
+
if pred_coord.shape[0] == 0:
|
534 |
+
return 1.0, pred, label
|
535 |
+
|
536 |
+
# Normalized pairwise distances between pred and label
|
537 |
+
dist_mat = np.divide(euclidean_distances(pred_coord, label_coord), pred.shape[0])
|
538 |
+
|
539 |
+
# Standard deviation of the minimum distance from pred to label
|
540 |
+
edge_coherence = np.std(np.min(dist_mat, axis=1))
|
541 |
+
|
542 |
+
return edge_coherence, pred, label
|
543 |
+
|
544 |
+
|
545 |
+
def boxplot_metric(
|
546 |
+
output_filename,
|
547 |
+
df,
|
548 |
+
metric,
|
549 |
+
dict_metrics,
|
550 |
+
do_stripplot=False,
|
551 |
+
dict_models=None,
|
552 |
+
dpi=300,
|
553 |
+
**snskwargs
|
554 |
+
):
|
555 |
+
f = plt.figure(dpi=dpi)
|
556 |
+
|
557 |
+
if do_stripplot:
|
558 |
+
ax = sns.boxplot(x="model", y=metric, data=df, fliersize=0.0, **snskwargs)
|
559 |
+
ax = sns.stripplot(
|
560 |
+
x="model", y=metric, data=df, size=2.0, color="gray", **snskwargs
|
561 |
+
)
|
562 |
+
else:
|
563 |
+
ax = sns.boxplot(x="model", y=metric, data=df, **snskwargs)
|
564 |
+
|
565 |
+
# Set axes labels
|
566 |
+
ax.set_xlabel("Models", rotation=0, fontsize="medium")
|
567 |
+
ax.set_ylabel(dict_metrics[metric], rotation=90, fontsize="medium")
|
568 |
+
|
569 |
+
# Spines
|
570 |
+
sns.despine(left=True, bottom=True)
|
571 |
+
|
572 |
+
# X-Tick labels
|
573 |
+
if dict_models:
|
574 |
+
xticklabels = [dict_models[t.get_text()] for t in ax.get_xticklabels()]
|
575 |
+
ax.set_xticklabels(
|
576 |
+
xticklabels,
|
577 |
+
rotation=20,
|
578 |
+
verticalalignment="top",
|
579 |
+
horizontalalignment="right",
|
580 |
+
fontsize="xx-small",
|
581 |
+
)
|
582 |
+
|
583 |
+
f.savefig(
|
584 |
+
output_filename,
|
585 |
+
dpi=f.dpi,
|
586 |
+
bbox_inches="tight",
|
587 |
+
facecolor="white",
|
588 |
+
transparent=False,
|
589 |
+
)
|
590 |
+
f.clear()
|
591 |
+
plt.close(f)
|
592 |
+
|
593 |
+
|
594 |
+
def clustermap_metric(
|
595 |
+
output_filename,
|
596 |
+
df,
|
597 |
+
metric,
|
598 |
+
dict_metrics,
|
599 |
+
method="average",
|
600 |
+
cluster_metric="euclidean",
|
601 |
+
dict_models=None,
|
602 |
+
dpi=300,
|
603 |
+
**snskwargs
|
604 |
+
):
|
605 |
+
ax_grid = sns.clustermap(data=df, method=method, metric=cluster_metric, **snskwargs)
|
606 |
+
ax_heatmap = ax_grid.ax_heatmap
|
607 |
+
ax_cbar = ax_grid.ax_cbar
|
608 |
+
|
609 |
+
# Set axes labels
|
610 |
+
ax_heatmap.set_xlabel("Models", rotation=0, fontsize="medium")
|
611 |
+
ax_heatmap.set_ylabel("Images", rotation=90, fontsize="medium")
|
612 |
+
|
613 |
+
# Set title
|
614 |
+
ax_cbar.set_title(dict_metrics[metric], rotation=0, fontsize="x-large")
|
615 |
+
|
616 |
+
# X-Tick labels
|
617 |
+
if dict_models:
|
618 |
+
xticklabels = [dict_models[t.get_text()] for t in ax_heatmap.get_xticklabels()]
|
619 |
+
ax_heatmap.set_xticklabels(
|
620 |
+
xticklabels,
|
621 |
+
rotation=20,
|
622 |
+
verticalalignment="top",
|
623 |
+
horizontalalignment="right",
|
624 |
+
fontsize="small",
|
625 |
+
)
|
626 |
+
|
627 |
+
ax_grid.fig.savefig(
|
628 |
+
output_filename,
|
629 |
+
dpi=dpi,
|
630 |
+
bbox_inches="tight",
|
631 |
+
facecolor="white",
|
632 |
+
transparent=False,
|
633 |
+
)
|
634 |
+
ax_grid.fig.clear()
|
635 |
+
plt.close(ax_grid.fig)
|
climategan/fid.py
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://github.com/mseitzer/pytorch-fid
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchvision
|
8 |
+
from scipy import linalg
|
9 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
10 |
+
|
11 |
+
try:
|
12 |
+
from torchvision.models.utils import load_state_dict_from_url
|
13 |
+
except ImportError:
|
14 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
15 |
+
|
16 |
+
FID_WEIGHTS_URL = (
|
17 |
+
"https://github.com/mseitzer/pytorch-fid/releases/download/"
|
18 |
+
+ "fid_weights/pt_inception-2015-12-05-6726825d.pth"
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class InceptionV3(nn.Module):
|
23 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
24 |
+
|
25 |
+
# Index of default block of inception to return,
|
26 |
+
# corresponds to output of final average pooling
|
27 |
+
DEFAULT_BLOCK_INDEX = 3
|
28 |
+
|
29 |
+
# Maps feature dimensionality to their output blocks indices
|
30 |
+
BLOCK_INDEX_BY_DIM = {
|
31 |
+
64: 0, # First max pooling features
|
32 |
+
192: 1, # Second max pooling features
|
33 |
+
768: 2, # Pre-aux classifier features
|
34 |
+
2048: 3, # Final average pooling features
|
35 |
+
}
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
40 |
+
resize_input=True,
|
41 |
+
normalize_input=True,
|
42 |
+
requires_grad=False,
|
43 |
+
use_fid_inception=True,
|
44 |
+
):
|
45 |
+
"""Build pretrained InceptionV3
|
46 |
+
Parameters
|
47 |
+
----------
|
48 |
+
output_blocks : list of int
|
49 |
+
Indices of blocks to return features of. Possible values are:
|
50 |
+
- 0: corresponds to output of first max pooling
|
51 |
+
- 1: corresponds to output of second max pooling
|
52 |
+
- 2: corresponds to output which is fed to aux classifier
|
53 |
+
- 3: corresponds to output of final average pooling
|
54 |
+
resize_input : bool
|
55 |
+
If true, bilinearly resizes input to width and height 299 before
|
56 |
+
feeding input to model. As the network without fully connected
|
57 |
+
layers is fully convolutional, it should be able to handle inputs
|
58 |
+
of arbitrary size, so resizing might not be strictly needed
|
59 |
+
normalize_input : bool
|
60 |
+
If true, scales the input from range (0, 1) to the range the
|
61 |
+
pretrained Inception network expects, namely (-1, 1)
|
62 |
+
requires_grad : bool
|
63 |
+
If true, parameters of the model require gradients. Possibly useful
|
64 |
+
for finetuning the network
|
65 |
+
use_fid_inception : bool
|
66 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
67 |
+
FID implementation. If false, uses the pretrained Inception model
|
68 |
+
available in torchvision. The FID Inception model has different
|
69 |
+
weights and a slightly different structure from torchvision's
|
70 |
+
Inception model. If you want to compute FID scores, you are
|
71 |
+
strongly advised to set this parameter to true to get comparable
|
72 |
+
results.
|
73 |
+
"""
|
74 |
+
super(InceptionV3, self).__init__()
|
75 |
+
|
76 |
+
self.resize_input = resize_input
|
77 |
+
self.normalize_input = normalize_input
|
78 |
+
self.output_blocks = sorted(output_blocks)
|
79 |
+
self.last_needed_block = max(output_blocks)
|
80 |
+
|
81 |
+
assert self.last_needed_block <= 3, "Last possible output block index is 3"
|
82 |
+
|
83 |
+
self.blocks = nn.ModuleList()
|
84 |
+
|
85 |
+
if use_fid_inception:
|
86 |
+
inception = fid_inception_v3()
|
87 |
+
else:
|
88 |
+
inception = _inception_v3(pretrained=True)
|
89 |
+
|
90 |
+
# Block 0: input to maxpool1
|
91 |
+
block0 = [
|
92 |
+
inception.Conv2d_1a_3x3,
|
93 |
+
inception.Conv2d_2a_3x3,
|
94 |
+
inception.Conv2d_2b_3x3,
|
95 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
96 |
+
]
|
97 |
+
self.blocks.append(nn.Sequential(*block0))
|
98 |
+
|
99 |
+
# Block 1: maxpool1 to maxpool2
|
100 |
+
if self.last_needed_block >= 1:
|
101 |
+
block1 = [
|
102 |
+
inception.Conv2d_3b_1x1,
|
103 |
+
inception.Conv2d_4a_3x3,
|
104 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
105 |
+
]
|
106 |
+
self.blocks.append(nn.Sequential(*block1))
|
107 |
+
|
108 |
+
# Block 2: maxpool2 to aux classifier
|
109 |
+
if self.last_needed_block >= 2:
|
110 |
+
block2 = [
|
111 |
+
inception.Mixed_5b,
|
112 |
+
inception.Mixed_5c,
|
113 |
+
inception.Mixed_5d,
|
114 |
+
inception.Mixed_6a,
|
115 |
+
inception.Mixed_6b,
|
116 |
+
inception.Mixed_6c,
|
117 |
+
inception.Mixed_6d,
|
118 |
+
inception.Mixed_6e,
|
119 |
+
]
|
120 |
+
self.blocks.append(nn.Sequential(*block2))
|
121 |
+
|
122 |
+
# Block 3: aux classifier to final avgpool
|
123 |
+
if self.last_needed_block >= 3:
|
124 |
+
block3 = [
|
125 |
+
inception.Mixed_7a,
|
126 |
+
inception.Mixed_7b,
|
127 |
+
inception.Mixed_7c,
|
128 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
|
129 |
+
]
|
130 |
+
self.blocks.append(nn.Sequential(*block3))
|
131 |
+
|
132 |
+
for param in self.parameters():
|
133 |
+
param.requires_grad = requires_grad
|
134 |
+
|
135 |
+
def forward(self, inp):
|
136 |
+
"""Get Inception feature maps
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
inp : torch.autograd.Variable
|
140 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
141 |
+
range (0, 1)
|
142 |
+
Returns
|
143 |
+
-------
|
144 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
145 |
+
block, sorted ascending by index
|
146 |
+
"""
|
147 |
+
outp = []
|
148 |
+
x = inp
|
149 |
+
|
150 |
+
if self.resize_input:
|
151 |
+
x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
|
152 |
+
|
153 |
+
if self.normalize_input:
|
154 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
155 |
+
|
156 |
+
for idx, block in enumerate(self.blocks):
|
157 |
+
x = block(x)
|
158 |
+
if idx in self.output_blocks:
|
159 |
+
outp.append(x)
|
160 |
+
|
161 |
+
if idx == self.last_needed_block:
|
162 |
+
break
|
163 |
+
|
164 |
+
return outp
|
165 |
+
|
166 |
+
|
167 |
+
def _inception_v3(*args, **kwargs):
|
168 |
+
"""Wraps `torchvision.models.inception_v3`
|
169 |
+
Skips default weight initialization if supported by torchvision version.
|
170 |
+
See https://github.com/mseitzer/pytorch-fid/issues/28.
|
171 |
+
"""
|
172 |
+
try:
|
173 |
+
version = tuple(map(int, torchvision.__version__.split(".")[:2]))
|
174 |
+
except ValueError:
|
175 |
+
# Just a caution against weird version strings
|
176 |
+
version = (0,)
|
177 |
+
|
178 |
+
if version >= (0, 6):
|
179 |
+
kwargs["init_weights"] = False
|
180 |
+
|
181 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
182 |
+
|
183 |
+
|
184 |
+
def fid_inception_v3():
|
185 |
+
"""Build pretrained Inception model for FID computation
|
186 |
+
The Inception model for FID computation uses a different set of weights
|
187 |
+
and has a slightly different structure than torchvision's Inception.
|
188 |
+
This method first constructs torchvision's Inception and then patches the
|
189 |
+
necessary parts that are different in the FID Inception model.
|
190 |
+
"""
|
191 |
+
inception = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
|
192 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
193 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
194 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
195 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
196 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
197 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
198 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
199 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
200 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
201 |
+
|
202 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
203 |
+
inception.load_state_dict(state_dict)
|
204 |
+
return inception
|
205 |
+
|
206 |
+
|
207 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
208 |
+
"""InceptionA block patched for FID computation"""
|
209 |
+
|
210 |
+
def __init__(self, in_channels, pool_features):
|
211 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
branch1x1 = self.branch1x1(x)
|
215 |
+
|
216 |
+
branch5x5 = self.branch5x5_1(x)
|
217 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
218 |
+
|
219 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
220 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
221 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
222 |
+
|
223 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
224 |
+
# its average calculation
|
225 |
+
branch_pool = F.avg_pool2d(
|
226 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
227 |
+
)
|
228 |
+
branch_pool = self.branch_pool(branch_pool)
|
229 |
+
|
230 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
231 |
+
return torch.cat(outputs, 1)
|
232 |
+
|
233 |
+
|
234 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
235 |
+
"""InceptionC block patched for FID computation"""
|
236 |
+
|
237 |
+
def __init__(self, in_channels, channels_7x7):
|
238 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
branch1x1 = self.branch1x1(x)
|
242 |
+
|
243 |
+
branch7x7 = self.branch7x7_1(x)
|
244 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
245 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
246 |
+
|
247 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
248 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
249 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
250 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
251 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
252 |
+
|
253 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
254 |
+
# its average calculation
|
255 |
+
branch_pool = F.avg_pool2d(
|
256 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
257 |
+
)
|
258 |
+
branch_pool = self.branch_pool(branch_pool)
|
259 |
+
|
260 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
261 |
+
return torch.cat(outputs, 1)
|
262 |
+
|
263 |
+
|
264 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
265 |
+
"""First InceptionE block patched for FID computation"""
|
266 |
+
|
267 |
+
def __init__(self, in_channels):
|
268 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
269 |
+
|
270 |
+
def forward(self, x):
|
271 |
+
branch1x1 = self.branch1x1(x)
|
272 |
+
|
273 |
+
branch3x3 = self.branch3x3_1(x)
|
274 |
+
branch3x3 = [
|
275 |
+
self.branch3x3_2a(branch3x3),
|
276 |
+
self.branch3x3_2b(branch3x3),
|
277 |
+
]
|
278 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
279 |
+
|
280 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
281 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
282 |
+
branch3x3dbl = [
|
283 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
284 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
285 |
+
]
|
286 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
287 |
+
|
288 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
289 |
+
# its average calculation
|
290 |
+
branch_pool = F.avg_pool2d(
|
291 |
+
x, kernel_size=3, stride=1, padding=1, count_include_pad=False
|
292 |
+
)
|
293 |
+
branch_pool = self.branch_pool(branch_pool)
|
294 |
+
|
295 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
296 |
+
return torch.cat(outputs, 1)
|
297 |
+
|
298 |
+
|
299 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
300 |
+
"""Second InceptionE block patched for FID computation"""
|
301 |
+
|
302 |
+
def __init__(self, in_channels):
|
303 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
304 |
+
|
305 |
+
def forward(self, x):
|
306 |
+
branch1x1 = self.branch1x1(x)
|
307 |
+
|
308 |
+
branch3x3 = self.branch3x3_1(x)
|
309 |
+
branch3x3 = [
|
310 |
+
self.branch3x3_2a(branch3x3),
|
311 |
+
self.branch3x3_2b(branch3x3),
|
312 |
+
]
|
313 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
314 |
+
|
315 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
316 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
317 |
+
branch3x3dbl = [
|
318 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
319 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
320 |
+
]
|
321 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
322 |
+
|
323 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
324 |
+
# pooling. This is likely an error in this specific Inception
|
325 |
+
# implementation, as other Inception models use average pooling here
|
326 |
+
# (which matches the description in the paper).
|
327 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
328 |
+
branch_pool = self.branch_pool(branch_pool)
|
329 |
+
|
330 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
331 |
+
return torch.cat(outputs, 1)
|
332 |
+
|
333 |
+
|
334 |
+
def compute_val_fid(trainer, verbose=0):
|
335 |
+
"""
|
336 |
+
Compute the fid score between the n=opts.train.fid.n_images real images
|
337 |
+
from the validation set (domain is rf) and n fake images pained from
|
338 |
+
those n validation images
|
339 |
+
|
340 |
+
Args:
|
341 |
+
trainer (climategan.Trainer): trainer to compute the val fid for
|
342 |
+
|
343 |
+
Returns:
|
344 |
+
float: FID score
|
345 |
+
"""
|
346 |
+
# get opts params
|
347 |
+
batch_size = trainer.opts.train.fid.get("batch_size", 50)
|
348 |
+
dims = trainer.opts.train.fid.get("dims", 2048)
|
349 |
+
|
350 |
+
# set inception model
|
351 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
352 |
+
model = InceptionV3([block_idx]).to(trainer.device)
|
353 |
+
|
354 |
+
# first fid computation: compute the real stats, only once
|
355 |
+
if trainer.real_val_fid_stats is None:
|
356 |
+
if verbose > 0:
|
357 |
+
print("Computing real_val_fid_stats for the first time")
|
358 |
+
set_real_val_fid_stats(trainer, model, batch_size, dims)
|
359 |
+
|
360 |
+
# get real stats
|
361 |
+
real_m = trainer.real_val_fid_stats["m"]
|
362 |
+
real_s = trainer.real_val_fid_stats["s"]
|
363 |
+
|
364 |
+
# compute fake images
|
365 |
+
fakes = compute_fakes(trainer)
|
366 |
+
if verbose > 0:
|
367 |
+
print("Computing fake activation statistics")
|
368 |
+
# get fake stats
|
369 |
+
fake_m, fake_s = calculate_activation_statistics(
|
370 |
+
fakes, model, batch_size=batch_size, dims=dims, device=trainer.device
|
371 |
+
)
|
372 |
+
# compute FD between the real and the fake inception stats
|
373 |
+
return calculate_frechet_distance(real_m, real_s, fake_m, fake_s)
|
374 |
+
|
375 |
+
|
376 |
+
def set_real_val_fid_stats(trainer, model, batch_size, dims):
|
377 |
+
"""
|
378 |
+
Sets the real_val_fid_stats attribute of the trainer with the m and
|
379 |
+
s outputs of calculate_activation_statistics on the real data.
|
380 |
+
|
381 |
+
This needs to be done only once since nothing changes during training here.
|
382 |
+
|
383 |
+
Args:
|
384 |
+
trainer (climategan.Trainer): trainer instance to compute the stats for
|
385 |
+
model (InceptionV3): inception model to get the activations from
|
386 |
+
batch_size (int): inception inference batch size
|
387 |
+
dims (int): dimension selected in the model
|
388 |
+
"""
|
389 |
+
# in the rf domain display_size may be different from fid.n_images
|
390 |
+
limit = trainer.opts.train.fid.n_images
|
391 |
+
display_x = torch.stack(
|
392 |
+
[sample["data"]["x"] for sample in trainer.display_images["val"]["rf"][:limit]]
|
393 |
+
).to(trainer.device)
|
394 |
+
m, s = calculate_activation_statistics(
|
395 |
+
display_x, model, batch_size=batch_size, dims=dims, device=trainer.device
|
396 |
+
)
|
397 |
+
trainer.real_val_fid_stats = {"m": m, "s": s}
|
398 |
+
|
399 |
+
|
400 |
+
def compute_fakes(trainer, verbose=0):
|
401 |
+
"""
|
402 |
+
Compute current fake inferences
|
403 |
+
|
404 |
+
Args:
|
405 |
+
trainer (climategan.Trainer): trainer instance
|
406 |
+
verbose (int, optional): Print level. Defaults to 0.
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
torch.Tensor: trainer.opts.train.fid.n_images painted images
|
410 |
+
"""
|
411 |
+
# in the rf domain display_size may be different from fid.n_images
|
412 |
+
n = trainer.opts.train.fid.n_images
|
413 |
+
bs = trainer.opts.data.loaders.batch_size
|
414 |
+
|
415 |
+
display_batches = [
|
416 |
+
(sample["data"]["x"], sample["data"]["m"])
|
417 |
+
for sample in trainer.display_images["val"]["rf"][:n]
|
418 |
+
]
|
419 |
+
|
420 |
+
display_x = torch.stack([b[0] for b in display_batches]).to(trainer.device)
|
421 |
+
display_m = torch.stack([b[0] for b in display_batches]).to(trainer.device)
|
422 |
+
nbs = len(display_x) // bs + 1
|
423 |
+
|
424 |
+
fakes = []
|
425 |
+
for b in range(nbs):
|
426 |
+
if verbose > 0:
|
427 |
+
print("computing fakes {}/{}".format(b + 1, nbs), end="\r", flush=True)
|
428 |
+
with torch.no_grad():
|
429 |
+
x = display_x[b * bs : (b + 1) * bs]
|
430 |
+
m = display_m[b * bs : (b + 1) * bs]
|
431 |
+
fake = trainer.G.paint(m, x)
|
432 |
+
fakes.append(fake)
|
433 |
+
|
434 |
+
return torch.cat(fakes, dim=0)
|
435 |
+
|
436 |
+
|
437 |
+
def calculate_activation_statistics(
|
438 |
+
images, model, batch_size=50, dims=2048, device="cpu"
|
439 |
+
):
|
440 |
+
"""Calculation of the statistics used by the FID.
|
441 |
+
Params:
|
442 |
+
-- images : List of images
|
443 |
+
-- model : Instance of inception model
|
444 |
+
-- batch_size : The images numpy array is split into batches with
|
445 |
+
batch size batch_size. A reasonable batch size
|
446 |
+
depends on the hardware.
|
447 |
+
-- dims : Dimensionality of features returned by Inception
|
448 |
+
-- device : Device to run calculations
|
449 |
+
Returns:
|
450 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
451 |
+
the inception model.
|
452 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
453 |
+
the inception model.
|
454 |
+
"""
|
455 |
+
act = get_activations(images, model, batch_size, dims, device)
|
456 |
+
mu = np.mean(act, axis=0)
|
457 |
+
sigma = np.cov(act, rowvar=False)
|
458 |
+
return mu, sigma
|
459 |
+
|
460 |
+
|
461 |
+
def get_activations(images, model, batch_size=50, dims=2048, device="cpu"):
|
462 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
463 |
+
Params:
|
464 |
+
-- images : List of images
|
465 |
+
-- model : Instance of inception model
|
466 |
+
-- batch_size : Batch size of images for the model to process at once.
|
467 |
+
Make sure that the number of samples is a multiple of
|
468 |
+
the batch size, otherwise some samples are ignored. This
|
469 |
+
behavior is retained to match the original FID score
|
470 |
+
implementation.
|
471 |
+
-- dims : Dimensionality of features returned by Inception
|
472 |
+
-- device : Device to run calculations
|
473 |
+
Returns:
|
474 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
475 |
+
activations of the given tensor when feeding inception with the
|
476 |
+
query tensor.
|
477 |
+
"""
|
478 |
+
model.eval()
|
479 |
+
|
480 |
+
pred_arr = np.empty((len(images), dims))
|
481 |
+
|
482 |
+
start_idx = 0
|
483 |
+
nbs = len(images) // batch_size + 1
|
484 |
+
|
485 |
+
for b in range(nbs):
|
486 |
+
batch = images[b * batch_size : (b + 1) * batch_size].to(device)
|
487 |
+
if not batch.nelement():
|
488 |
+
continue
|
489 |
+
|
490 |
+
with torch.no_grad():
|
491 |
+
pred = model(batch)[0]
|
492 |
+
|
493 |
+
# If model output is not scalar, apply global spatial average pooling.
|
494 |
+
# This happens if you choose a dimensionality not equal 2048.
|
495 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
496 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
497 |
+
|
498 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
499 |
+
|
500 |
+
pred_arr[start_idx : start_idx + pred.shape[0]] = pred
|
501 |
+
|
502 |
+
start_idx = start_idx + pred.shape[0]
|
503 |
+
|
504 |
+
return pred_arr
|
505 |
+
|
506 |
+
|
507 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
508 |
+
"""Numpy implementation of the Frechet Distance.
|
509 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
510 |
+
and X_2 ~ N(mu_2, C_2) is
|
511 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
512 |
+
Stable version by Dougal J. Sutherland.
|
513 |
+
Params:
|
514 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
515 |
+
inception net (like returned by the function 'get_predictions')
|
516 |
+
for generated samples.
|
517 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
518 |
+
representative data set.
|
519 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
520 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
521 |
+
representative data set.
|
522 |
+
Returns:
|
523 |
+
-- : The Frechet Distance.
|
524 |
+
"""
|
525 |
+
|
526 |
+
mu1 = np.atleast_1d(mu1)
|
527 |
+
mu2 = np.atleast_1d(mu2)
|
528 |
+
|
529 |
+
sigma1 = np.atleast_2d(sigma1)
|
530 |
+
sigma2 = np.atleast_2d(sigma2)
|
531 |
+
|
532 |
+
assert (
|
533 |
+
mu1.shape == mu2.shape
|
534 |
+
), "Training and test mean vectors have different lengths"
|
535 |
+
assert (
|
536 |
+
sigma1.shape == sigma2.shape
|
537 |
+
), "Training and test covariances have different dimensions"
|
538 |
+
|
539 |
+
diff = mu1 - mu2
|
540 |
+
|
541 |
+
# Product might be almost singular
|
542 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
543 |
+
if not np.isfinite(covmean).all():
|
544 |
+
msg = (
|
545 |
+
"fid calculation produces singular product; "
|
546 |
+
"adding %s to diagonal of cov estimates"
|
547 |
+
) % eps
|
548 |
+
print(msg)
|
549 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
550 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
551 |
+
|
552 |
+
# Numerical error might give slight imaginary component
|
553 |
+
if np.iscomplexobj(covmean):
|
554 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
555 |
+
m = np.max(np.abs(covmean.imag))
|
556 |
+
raise ValueError("Imaginary component {}".format(m))
|
557 |
+
covmean = covmean.real
|
558 |
+
|
559 |
+
tr_covmean = np.trace(covmean)
|
560 |
+
|
561 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
climategan/fire.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import random
|
4 |
+
import kornia
|
5 |
+
from torchvision.transforms.functional import adjust_brightness, adjust_contrast
|
6 |
+
|
7 |
+
from climategan.tutils import normalize, retrieve_sky_mask
|
8 |
+
|
9 |
+
try:
|
10 |
+
from kornia.filters import filter2d
|
11 |
+
except ImportError:
|
12 |
+
from kornia.filters import filter2D as filter2d
|
13 |
+
|
14 |
+
|
15 |
+
def increase_sky_mask(mask, p_w=0, p_h=0):
|
16 |
+
"""
|
17 |
+
Increases sky mask in width and height by a given pourcentage
|
18 |
+
(Purpose: when applying Gaussian blur, there are no artifacts of blue sky behind)
|
19 |
+
Args:
|
20 |
+
sky_mask (torch.Tensor): Sky mask of shape (H,W)
|
21 |
+
p_w (float): Percentage of mask width by which to increase
|
22 |
+
the width of the sky region
|
23 |
+
p_h (float): Percentage of mask height by which to increase
|
24 |
+
the height of the sky region
|
25 |
+
Returns:
|
26 |
+
torch.Tensor: Sky mask increased given p_w and p_h
|
27 |
+
"""
|
28 |
+
|
29 |
+
if p_h <= 0 and p_w <= 0:
|
30 |
+
return mask
|
31 |
+
|
32 |
+
n_lines = int(p_h * mask.shape[-2])
|
33 |
+
n_cols = int(p_w * mask.shape[-1])
|
34 |
+
|
35 |
+
temp_mask = mask.clone().detach()
|
36 |
+
for i in range(1, n_cols):
|
37 |
+
temp_mask[:, :, :, i::] += mask[:, :, :, 0:-i]
|
38 |
+
temp_mask[:, :, :, 0:-i] += mask[:, :, :, i::]
|
39 |
+
|
40 |
+
new_mask = temp_mask.clone().detach()
|
41 |
+
for i in range(1, n_lines):
|
42 |
+
new_mask[:, :, i::, :] += temp_mask[:, :, 0:-i, :]
|
43 |
+
new_mask[:, :, 0:-i, :] += temp_mask[:, :, i::, :]
|
44 |
+
|
45 |
+
new_mask[new_mask >= 1] = 1
|
46 |
+
|
47 |
+
return new_mask
|
48 |
+
|
49 |
+
|
50 |
+
def paste_filter(x, filter_, mask):
|
51 |
+
"""
|
52 |
+
Pastes a filter over an image given a mask
|
53 |
+
Where the mask is 1, the filter is copied as is.
|
54 |
+
Where the mask is 0, the current value is preserved.
|
55 |
+
Intermediate values will mix the two images together.
|
56 |
+
Args:
|
57 |
+
x (torch.Tensor): Input tensor, range must be [0, 255]
|
58 |
+
filer_ (torch.Tensor): Filter, range must be [0, 255]
|
59 |
+
mask (torch.Tensor): Mask, range must be [0, 1]
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: New tensor with filter pasted on it
|
62 |
+
"""
|
63 |
+
assert len(x.shape) == len(filter_.shape) == len(mask.shape)
|
64 |
+
x = filter_ * mask + x * (1 - mask)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
def add_fire(x, seg_preds, fire_opts):
|
69 |
+
"""
|
70 |
+
Transforms input tensor given wildfires event
|
71 |
+
Args:
|
72 |
+
x (torch.Tensor): Input tensor
|
73 |
+
seg_preds (torch.Tensor): Semantic segmentation predictions for input tensor
|
74 |
+
filter_color (tuple): (r,g,b) tuple for the color of the sky
|
75 |
+
blur_radius (float): radius of the Gaussian blur that smooths
|
76 |
+
the transition between sky and foreground
|
77 |
+
Returns:
|
78 |
+
torch.Tensor: Wildfire version of input tensor
|
79 |
+
"""
|
80 |
+
wildfire_tens = normalize(x, 0, 255)
|
81 |
+
|
82 |
+
# Warm the image
|
83 |
+
wildfire_tens[:, 2, :, :] -= 20
|
84 |
+
wildfire_tens[:, 1, :, :] -= 10
|
85 |
+
wildfire_tens[:, 0, :, :] += 40
|
86 |
+
wildfire_tens.clamp_(0, 255)
|
87 |
+
wildfire_tens = wildfire_tens.to(torch.uint8)
|
88 |
+
|
89 |
+
# Darken the picture and increase contrast
|
90 |
+
wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5)
|
91 |
+
wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73)
|
92 |
+
|
93 |
+
sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1)
|
94 |
+
|
95 |
+
if fire_opts.get("crop_bottom_sky_mask"):
|
96 |
+
i = 2 * sky_mask.shape[-2] // 3
|
97 |
+
sky_mask[..., i:, :] = 0
|
98 |
+
|
99 |
+
sky_mask = F.interpolate(
|
100 |
+
sky_mask.to(torch.float),
|
101 |
+
(wildfire_tens.shape[-2], wildfire_tens.shape[-1]),
|
102 |
+
)
|
103 |
+
sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18)
|
104 |
+
|
105 |
+
kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301))
|
106 |
+
sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5))
|
107 |
+
border_type = "reflect"
|
108 |
+
kernel = torch.unsqueeze(
|
109 |
+
kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma), dim=0
|
110 |
+
).to(x.device)
|
111 |
+
sky_mask = filter2d(sky_mask, kernel, border_type)
|
112 |
+
|
113 |
+
filter_ = torch.ones(wildfire_tens.shape, device=x.device)
|
114 |
+
filter_[:, 0, :, :] = 255
|
115 |
+
filter_[:, 1, :, :] = random.randint(100, 150)
|
116 |
+
filter_[:, 2, :, :] = 0
|
117 |
+
|
118 |
+
wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200)
|
119 |
+
|
120 |
+
wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8)
|
121 |
+
wildfire_tens = wildfire_tens.to(torch.float)
|
122 |
+
|
123 |
+
# dummy pixels to fool scaling and preserve range
|
124 |
+
wildfire_tens[:, :, 0, 0] = 255.0
|
125 |
+
wildfire_tens[:, :, -1, -1] = 0.0
|
126 |
+
|
127 |
+
return wildfire_tens
|
128 |
+
|
129 |
+
|
130 |
+
def paste_tensor(source, filter_, mask, transparency):
|
131 |
+
mask = transparency / 255.0 * mask
|
132 |
+
new = mask * filter_ + (1.0 - mask) * source
|
133 |
+
return new
|
climategan/generator.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Complete Generator architecture:
|
2 |
+
* OmniGenerator
|
3 |
+
* Encoder
|
4 |
+
* Decoders
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import yaml
|
13 |
+
from addict import Dict
|
14 |
+
from torch import softmax
|
15 |
+
|
16 |
+
import climategan.strings as strings
|
17 |
+
from climategan.deeplab import create_encoder, create_segmentation_decoder
|
18 |
+
from climategan.depth import create_depth_decoder
|
19 |
+
from climategan.masker import create_mask_decoder
|
20 |
+
from climategan.painter import create_painter
|
21 |
+
from climategan.tutils import init_weights, mix_noise, normalize
|
22 |
+
|
23 |
+
|
24 |
+
def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
|
25 |
+
G = OmniGenerator(opts, latent_shape, verbose, no_init)
|
26 |
+
if no_init:
|
27 |
+
print("Sending to", device)
|
28 |
+
return G.to(device)
|
29 |
+
|
30 |
+
for model in G.decoders:
|
31 |
+
net = G.decoders[model]
|
32 |
+
if model == "s":
|
33 |
+
continue
|
34 |
+
if isinstance(net, nn.ModuleDict):
|
35 |
+
for domain, domain_model in net.items():
|
36 |
+
init_weights(
|
37 |
+
net[domain_model],
|
38 |
+
init_type=opts.gen[model].init_type,
|
39 |
+
init_gain=opts.gen[model].init_gain,
|
40 |
+
verbose=verbose,
|
41 |
+
caller=f"create_generator decoder {model} {domain}",
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
init_weights(
|
45 |
+
G.decoders[model],
|
46 |
+
init_type=opts.gen[model].init_type,
|
47 |
+
init_gain=opts.gen[model].init_gain,
|
48 |
+
verbose=verbose,
|
49 |
+
caller=f"create_generator decoder {model}",
|
50 |
+
)
|
51 |
+
if G.encoder is not None and opts.gen.encoder.architecture == "base":
|
52 |
+
init_weights(
|
53 |
+
G.encoder,
|
54 |
+
init_type=opts.gen.encoder.init_type,
|
55 |
+
init_gain=opts.gen.encoder.init_gain,
|
56 |
+
verbose=verbose,
|
57 |
+
caller="create_generator encoder",
|
58 |
+
)
|
59 |
+
|
60 |
+
print("Sending to", device)
|
61 |
+
return G.to(device)
|
62 |
+
|
63 |
+
|
64 |
+
class OmniGenerator(nn.Module):
|
65 |
+
def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
|
66 |
+
"""Creates the generator. All decoders listed in opts.gen will be added
|
67 |
+
to the Generator.decoders ModuleDict if opts.gen.DecoderInitial is not True.
|
68 |
+
Then can be accessed as G.decoders.T or G.decoders["T"] for instance,
|
69 |
+
for the image Translation decoder
|
70 |
+
|
71 |
+
Args:
|
72 |
+
opts (addict.Dict): configuration dict
|
73 |
+
"""
|
74 |
+
super().__init__()
|
75 |
+
self.opts = opts
|
76 |
+
self.verbose = verbose
|
77 |
+
self.encoder = None
|
78 |
+
if any(t in opts.tasks for t in "msd"):
|
79 |
+
self.encoder = create_encoder(opts, no_init, verbose)
|
80 |
+
|
81 |
+
self.decoders = {}
|
82 |
+
self.painter = nn.Module()
|
83 |
+
|
84 |
+
if "d" in opts.tasks:
|
85 |
+
self.decoders["d"] = create_depth_decoder(opts, no_init, verbose)
|
86 |
+
|
87 |
+
if self.verbose > 0:
|
88 |
+
print(f" - Add {self.decoders['d'].__class__.__name__}")
|
89 |
+
|
90 |
+
if "s" in opts.tasks:
|
91 |
+
self.decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)
|
92 |
+
|
93 |
+
if "m" in opts.tasks:
|
94 |
+
self.decoders["m"] = create_mask_decoder(opts, no_init, verbose)
|
95 |
+
|
96 |
+
self.decoders = nn.ModuleDict(self.decoders)
|
97 |
+
|
98 |
+
if "p" in self.opts.tasks:
|
99 |
+
self.painter = create_painter(opts, no_init, verbose)
|
100 |
+
else:
|
101 |
+
if self.verbose > 0:
|
102 |
+
print(" - Add Empty Painter")
|
103 |
+
|
104 |
+
def __str__(self):
|
105 |
+
return strings.generator(self)
|
106 |
+
|
107 |
+
def encode(self, x):
|
108 |
+
"""
|
109 |
+
Forward x through the encoder
|
110 |
+
|
111 |
+
Args:
|
112 |
+
x (torch.Tensor): B3HW input tensor
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
list: High and Low level features from the encoder
|
116 |
+
"""
|
117 |
+
assert self.encoder is not None
|
118 |
+
return self.encoder.forward(x)
|
119 |
+
|
120 |
+
def decode(self, x=None, z=None, return_z=False, return_z_depth=False):
|
121 |
+
"""
|
122 |
+
Comptutes the predictions of all available decoders from either x or z.
|
123 |
+
If using spade for the masker with 15 channels, x *must* be provided,
|
124 |
+
whether z is too or not.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None.
|
128 |
+
z (list, optional): List of high and low-level features as BCHW.
|
129 |
+
Defaults to None.
|
130 |
+
return_z (bool, optional): whether or not to return z in the dict.
|
131 |
+
Defaults to False.
|
132 |
+
return_z_depth (bool, optional): whether or not to return z_depth
|
133 |
+
in the dict. Defaults to False.
|
134 |
+
|
135 |
+
Raises:
|
136 |
+
ValueError: If using spade for the masker with 15 channels but x is None
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
dict: {task: prediction_tensor} (may include z and z_depth
|
140 |
+
depending on args)
|
141 |
+
"""
|
142 |
+
|
143 |
+
assert x is not None or z is not None
|
144 |
+
if self.opts.gen.m.use_spade and self.opts.m.spade.cond_nc == 15:
|
145 |
+
if x is None:
|
146 |
+
raise ValueError(
|
147 |
+
"When using spade for the Masker with 15 channels,"
|
148 |
+
+ " x MUST be provided"
|
149 |
+
)
|
150 |
+
|
151 |
+
z_depth = cond = d = s = None
|
152 |
+
out = {}
|
153 |
+
|
154 |
+
if z is None:
|
155 |
+
z = self.encode(x)
|
156 |
+
|
157 |
+
if return_z:
|
158 |
+
out["z"] = z
|
159 |
+
|
160 |
+
if "d" in self.decoders:
|
161 |
+
d, z_depth = self.decoders["d"](z)
|
162 |
+
out["d"] = d
|
163 |
+
|
164 |
+
if return_z_depth:
|
165 |
+
out["z_depth"] = z_depth
|
166 |
+
|
167 |
+
if "s" in self.decoders:
|
168 |
+
s = self.decoders["s"](z, z_depth)
|
169 |
+
out["s"] = s
|
170 |
+
|
171 |
+
if "m" in self.decoders:
|
172 |
+
if s is not None and d is not None:
|
173 |
+
cond = self.make_m_cond(d, s, x)
|
174 |
+
m = self.mask(z=z, cond=cond)
|
175 |
+
out["m"] = m
|
176 |
+
|
177 |
+
return out
|
178 |
+
|
179 |
+
def sample_painter_z(self, batch_size, device, force_half=False):
|
180 |
+
if self.opts.gen.p.no_z:
|
181 |
+
return None
|
182 |
+
|
183 |
+
z = torch.empty(
|
184 |
+
batch_size,
|
185 |
+
self.opts.gen.p.latent_dim,
|
186 |
+
self.painter.z_h,
|
187 |
+
self.painter.z_w,
|
188 |
+
device=device,
|
189 |
+
).normal_(mean=0, std=1.0)
|
190 |
+
|
191 |
+
if force_half:
|
192 |
+
z = z.half()
|
193 |
+
|
194 |
+
return z
|
195 |
+
|
196 |
+
def make_m_cond(self, d, s, x=None):
|
197 |
+
"""
|
198 |
+
Create the masker's conditioning input when using spade from the
|
199 |
+
d and s predictions and from the input x when cond_nc == 15.
|
200 |
+
|
201 |
+
d and s are assumed to have the the same spatial resolution.
|
202 |
+
if cond_nc == 15 then x is interpolated to match that dimension.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
d (torch.Tensor): Raw depth prediction (B1HW)
|
206 |
+
s (torch.Tensor): Raw segmentation prediction (BCHW)
|
207 |
+
x (torch.Tensor, optional): Input tensor (B3hW). Mandatory
|
208 |
+
when opts.gen.m.spade.cond_nc == 15
|
209 |
+
|
210 |
+
Raises:
|
211 |
+
ValueError: opts.gen.m.spade.cond_nc == 15 but x is None
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
torch.Tensor: B x cond_nc x H x W conditioning tensor.
|
215 |
+
"""
|
216 |
+
if self.opts.gen.m.spade.detach:
|
217 |
+
d = d.detach()
|
218 |
+
s = s.detach()
|
219 |
+
cats = [normalize(d), softmax(s, dim=1)]
|
220 |
+
if self.opts.gen.m.spade.cond_nc == 15:
|
221 |
+
if x is None:
|
222 |
+
raise ValueError(
|
223 |
+
"When using spade for the Masker with 15 channels,"
|
224 |
+
+ " x MUST be provided"
|
225 |
+
)
|
226 |
+
cats += [
|
227 |
+
F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True)
|
228 |
+
]
|
229 |
+
|
230 |
+
return torch.cat(cats, dim=1)
|
231 |
+
|
232 |
+
def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True):
|
233 |
+
"""
|
234 |
+
Create a mask from either an input x or a latent vector z.
|
235 |
+
Optionally if the Masker has a spade architecture the conditioning tensor
|
236 |
+
may be provided (cond). Default behavior applies an element-wise
|
237 |
+
sigmoid, but can be deactivated (sigmoid=False).
|
238 |
+
|
239 |
+
At least one of x or z must be provided (i.e. not None).
|
240 |
+
If the Masker has a spade architecture and cond_nc == 15 then x cannot
|
241 |
+
be None.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
x (torch.Tensor, optional): Input tensor B3HW. Defaults to None.
|
245 |
+
z (list, optional): High and Low level features of the encoder.
|
246 |
+
Will be computed if None. Defaults to None.
|
247 |
+
cond ([type], optional): [description]. Defaults to None.
|
248 |
+
sigmoid (bool, optional): [description]. Defaults to True.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
torch.Tensor: B1HW mask tensor
|
252 |
+
"""
|
253 |
+
assert x is not None or z is not None
|
254 |
+
if z is None:
|
255 |
+
z = self.encode(x)
|
256 |
+
|
257 |
+
if cond is None and self.opts.gen.m.use_spade:
|
258 |
+
assert "s" in self.opts.tasks and "d" in self.opts.tasks
|
259 |
+
with torch.no_grad():
|
260 |
+
d_pred, z_d = self.decoders["d"](z)
|
261 |
+
s_pred = self.decoders["s"](z, z_d)
|
262 |
+
cond = self.make_m_cond(d_pred, s_pred, x)
|
263 |
+
if z_depth is None and self.opts.gen.m.use_dada:
|
264 |
+
assert "d" in self.opts.tasks
|
265 |
+
with torch.no_grad():
|
266 |
+
_, z_depth = self.decoders["d"](z)
|
267 |
+
|
268 |
+
if cond is not None:
|
269 |
+
device = z[0].device if isinstance(z, (tuple, list)) else z.device
|
270 |
+
cond = cond.to(device)
|
271 |
+
|
272 |
+
logits = self.decoders["m"](z, cond, z_depth)
|
273 |
+
|
274 |
+
if not sigmoid:
|
275 |
+
return logits
|
276 |
+
|
277 |
+
return torch.sigmoid(logits)
|
278 |
+
|
279 |
+
def paint(self, m, x, no_paste=False):
|
280 |
+
"""
|
281 |
+
Paints given a mask and an image
|
282 |
+
calls painter(z, x * (1.0 - m))
|
283 |
+
Mask has 1s where water should be painted
|
284 |
+
|
285 |
+
Args:
|
286 |
+
m (torch.Tensor): Mask
|
287 |
+
x (torch.Tensor): Image to paint
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
torch.Tensor: painted image
|
291 |
+
"""
|
292 |
+
z_paint = self.sample_painter_z(x.shape[0], x.device)
|
293 |
+
m = m.to(x.dtype)
|
294 |
+
fake = self.painter(z_paint, x * (1.0 - m))
|
295 |
+
if self.opts.gen.p.paste_original_content and not no_paste:
|
296 |
+
return x * (1.0 - m) + fake * m
|
297 |
+
return fake
|
298 |
+
|
299 |
+
def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8):
|
300 |
+
"""
|
301 |
+
Paints x with water in m through an intermediary cloudy image
|
302 |
+
where the sky has been replaced with perlin noise to imitate clouds.
|
303 |
+
|
304 |
+
The intermediary cloudy image is only used to control the painter's
|
305 |
+
painting mode, probing it with a cloudy input.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
m (torch.Tensor): water mask
|
309 |
+
x (torch.Tensor): input tensor
|
310 |
+
s (torch.Tensor): segmentation prediction (BCHW)
|
311 |
+
sky_idx (int, optional): Index of the sky class along s's C dimension.
|
312 |
+
Defaults to 9.
|
313 |
+
res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8).
|
314 |
+
weight (float, optional): Intermediate image's cloud proportion
|
315 |
+
(w * cloud + (1-w) * original_sky). Defaults to 0.8.
|
316 |
+
|
317 |
+
Returns:
|
318 |
+
torch.Tensor: painted image with original content pasted.
|
319 |
+
"""
|
320 |
+
sky_mask = (
|
321 |
+
torch.argmax(
|
322 |
+
F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True
|
323 |
+
)
|
324 |
+
== sky_idx
|
325 |
+
).to(x.dtype)
|
326 |
+
noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype)
|
327 |
+
fake = self.paint(m, noised_x, no_paste=True)
|
328 |
+
return x * (1.0 - m) + fake * m
|
329 |
+
|
330 |
+
def depth(self, x=None, z=None, return_z_depth=False):
|
331 |
+
"""
|
332 |
+
Compute the depth head's output
|
333 |
+
|
334 |
+
Args:
|
335 |
+
x (torch.Tensor, optional): Input B3HW tensor. Defaults to None.
|
336 |
+
z (list, optional): High and Low level features of the encoder.
|
337 |
+
Defaults to None.
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
torch.Tensor: B1HW tensor of depth predictions
|
341 |
+
"""
|
342 |
+
assert x is not None or z is not None
|
343 |
+
assert not (x is not None and z is not None)
|
344 |
+
if z is None:
|
345 |
+
z = self.encode(x)
|
346 |
+
depth, z_depth = self.decoders["d"](z)
|
347 |
+
|
348 |
+
if depth.shape[1] > 1:
|
349 |
+
depth = torch.argmax(depth, dim=1)
|
350 |
+
depth = depth / depth.max()
|
351 |
+
|
352 |
+
if return_z_depth:
|
353 |
+
return depth, z_depth
|
354 |
+
|
355 |
+
return depth
|
356 |
+
|
357 |
+
def load_val_painter(self):
|
358 |
+
"""
|
359 |
+
Loads a validation painter if available in opts.val.val_painter
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
bool: operation success status
|
363 |
+
"""
|
364 |
+
try:
|
365 |
+
# key exists in opts
|
366 |
+
assert self.opts.val.val_painter
|
367 |
+
|
368 |
+
# path exists
|
369 |
+
ckpt_path = Path(self.opts.val.val_painter).resolve()
|
370 |
+
assert ckpt_path.exists()
|
371 |
+
|
372 |
+
# path is a checkpoint path
|
373 |
+
assert ckpt_path.is_file()
|
374 |
+
|
375 |
+
# opts are available in that path
|
376 |
+
opts_path = ckpt_path.parent.parent / "opts.yaml"
|
377 |
+
assert opts_path.exists()
|
378 |
+
|
379 |
+
# load opts
|
380 |
+
with opts_path.open("r") as f:
|
381 |
+
val_painter_opts = Dict(yaml.safe_load(f))
|
382 |
+
|
383 |
+
# load checkpoint
|
384 |
+
state_dict = torch.load(ckpt_path)
|
385 |
+
|
386 |
+
# create dummy painter from loaded opts
|
387 |
+
painter = create_painter(val_painter_opts)
|
388 |
+
|
389 |
+
# load state-dict in the dummy painter
|
390 |
+
painter.load_state_dict(
|
391 |
+
{k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
|
392 |
+
)
|
393 |
+
|
394 |
+
# send to current device in evaluation mode
|
395 |
+
device = next(self.parameters()).device
|
396 |
+
self.painter = painter.eval().to(device)
|
397 |
+
|
398 |
+
# disable gradients
|
399 |
+
for p in self.painter.parameters():
|
400 |
+
p.requires_grad = False
|
401 |
+
|
402 |
+
# success
|
403 |
+
print(" - Loaded validation-only painter")
|
404 |
+
return True
|
405 |
+
|
406 |
+
except Exception as e:
|
407 |
+
# something happened, aborting gracefully
|
408 |
+
print(traceback.format_exc())
|
409 |
+
print(e)
|
410 |
+
print(">>> WARNING: error (^) in load_val_painter, aborting.")
|
411 |
+
return False
|
climategan/logger.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torchvision.utils as vutils
|
6 |
+
from addict import Dict
|
7 |
+
from PIL import Image
|
8 |
+
from torch.nn.functional import interpolate, sigmoid
|
9 |
+
|
10 |
+
from climategan.data import decode_segmap_merged_labels
|
11 |
+
from climategan.tutils import (
|
12 |
+
all_texts_to_tensors,
|
13 |
+
decode_bucketed_depth,
|
14 |
+
normalize_tensor,
|
15 |
+
write_architecture,
|
16 |
+
)
|
17 |
+
from climategan.utils import flatten_opts
|
18 |
+
|
19 |
+
|
20 |
+
class Logger:
|
21 |
+
def __init__(self, trainer):
|
22 |
+
self.losses = Dict()
|
23 |
+
self.time = Dict()
|
24 |
+
self.trainer = trainer
|
25 |
+
self.global_step = 0
|
26 |
+
self.epoch = 0
|
27 |
+
|
28 |
+
def log_comet_images(self, mode, domain, minimal=False, all_only=False):
|
29 |
+
trainer = self.trainer
|
30 |
+
save_images = {}
|
31 |
+
all_images = []
|
32 |
+
n_all_ims = None
|
33 |
+
all_legends = ["Input"]
|
34 |
+
task_legends = {}
|
35 |
+
|
36 |
+
if domain not in trainer.display_images[mode]:
|
37 |
+
return
|
38 |
+
|
39 |
+
# --------------------
|
40 |
+
# ----- Masker -----
|
41 |
+
# --------------------
|
42 |
+
n_ims = len(trainer.display_images[mode][domain])
|
43 |
+
print(" " * 60, end="\r")
|
44 |
+
if domain != "rf":
|
45 |
+
for j, display_dict in enumerate(trainer.display_images[mode][domain]):
|
46 |
+
|
47 |
+
print(f"Inferring sample {mode} {domain} {j+1}/{n_ims}", end="\r")
|
48 |
+
|
49 |
+
x = display_dict["data"]["x"].unsqueeze(0).to(trainer.device)
|
50 |
+
z = trainer.G.encode(x)
|
51 |
+
|
52 |
+
s_pred = decoded_s_pred = d_pred = z_depth = None
|
53 |
+
for k, task in enumerate(["d", "s", "m"]):
|
54 |
+
|
55 |
+
if (
|
56 |
+
task not in display_dict["data"]
|
57 |
+
or task not in trainer.opts.tasks
|
58 |
+
):
|
59 |
+
continue
|
60 |
+
|
61 |
+
task_legend = ["Input"]
|
62 |
+
target = display_dict["data"][task]
|
63 |
+
target = target.unsqueeze(0).to(trainer.device)
|
64 |
+
task_saves = []
|
65 |
+
|
66 |
+
if task not in save_images:
|
67 |
+
save_images[task] = []
|
68 |
+
|
69 |
+
prediction = None
|
70 |
+
if task == "m":
|
71 |
+
cond = None
|
72 |
+
if s_pred is not None and d_pred is not None:
|
73 |
+
cond = trainer.G.make_m_cond(d_pred, s_pred, x)
|
74 |
+
|
75 |
+
prediction = trainer.G.decoders[task](z, cond, z_depth)
|
76 |
+
elif task == "d":
|
77 |
+
prediction, z_depth = trainer.G.decoders[task](z)
|
78 |
+
elif task == "s":
|
79 |
+
prediction = trainer.G.decoders[task](z, z_depth)
|
80 |
+
|
81 |
+
if task == "s":
|
82 |
+
# Log fire
|
83 |
+
wildfire_tens = trainer.compute_fire(x, prediction)
|
84 |
+
task_saves.append(wildfire_tens)
|
85 |
+
task_legend.append("Wildfire")
|
86 |
+
# Log seg output
|
87 |
+
s_pred = prediction.clone()
|
88 |
+
target = (
|
89 |
+
decode_segmap_merged_labels(target, domain, True)
|
90 |
+
.float()
|
91 |
+
.to(trainer.device)
|
92 |
+
)
|
93 |
+
prediction = (
|
94 |
+
decode_segmap_merged_labels(prediction, domain, False)
|
95 |
+
.float()
|
96 |
+
.to(trainer.device)
|
97 |
+
)
|
98 |
+
decoded_s_pred = prediction
|
99 |
+
task_saves.append(target)
|
100 |
+
task_legend.append("Target Segmentation")
|
101 |
+
|
102 |
+
elif task == "m":
|
103 |
+
prediction = sigmoid(prediction).repeat(1, 3, 1, 1)
|
104 |
+
task_saves.append(x * (1.0 - prediction))
|
105 |
+
if not minimal:
|
106 |
+
task_saves.append(
|
107 |
+
x * (1.0 - (prediction > 0.1).to(torch.int))
|
108 |
+
)
|
109 |
+
task_saves.append(
|
110 |
+
x * (1.0 - (prediction > 0.5).to(torch.int))
|
111 |
+
)
|
112 |
+
|
113 |
+
task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1)))
|
114 |
+
task_legend.append("Masked input")
|
115 |
+
|
116 |
+
if not minimal:
|
117 |
+
task_legend.append("Masked input (>0.1)")
|
118 |
+
task_legend.append("Masked input (>0.5)")
|
119 |
+
|
120 |
+
task_legend.append("Masked input (target)")
|
121 |
+
# dummy pixels to fool scaling and preserve mask range
|
122 |
+
prediction[:, :, 0, 0] = 1.0
|
123 |
+
prediction[:, :, -1, -1] = 0.0
|
124 |
+
|
125 |
+
elif task == "d":
|
126 |
+
# prediction is a log depth tensor
|
127 |
+
d_pred = prediction
|
128 |
+
target = normalize_tensor(target) * 255
|
129 |
+
if prediction.shape[1] > 1:
|
130 |
+
prediction = decode_bucketed_depth(
|
131 |
+
prediction, self.trainer.opts
|
132 |
+
)
|
133 |
+
smogged = self.trainer.compute_smog(
|
134 |
+
x, d=prediction, s=decoded_s_pred, use_sky_seg=False
|
135 |
+
)
|
136 |
+
prediction = normalize_tensor(prediction)
|
137 |
+
prediction = prediction.repeat(1, 3, 1, 1)
|
138 |
+
task_saves.append(smogged)
|
139 |
+
task_legend.append("Smogged")
|
140 |
+
task_saves.append(target.repeat(1, 3, 1, 1))
|
141 |
+
task_legend.append("Depth target")
|
142 |
+
|
143 |
+
task_saves.append(prediction)
|
144 |
+
task_legend.append(f"Predicted {task}")
|
145 |
+
|
146 |
+
save_images[task].append(x.cpu().detach())
|
147 |
+
if k == 0:
|
148 |
+
all_images.append(save_images[task][-1])
|
149 |
+
|
150 |
+
task_legends[task] = task_legend
|
151 |
+
if j == 0:
|
152 |
+
all_legends += task_legend[1:]
|
153 |
+
|
154 |
+
for im in task_saves:
|
155 |
+
save_images[task].append(im.cpu().detach())
|
156 |
+
all_images.append(save_images[task][-1])
|
157 |
+
|
158 |
+
if j == 0:
|
159 |
+
n_all_ims = len(all_images)
|
160 |
+
|
161 |
+
if not all_only:
|
162 |
+
for task in save_images.keys():
|
163 |
+
# Write images:
|
164 |
+
self.upload_images(
|
165 |
+
image_outputs=save_images[task],
|
166 |
+
mode=mode,
|
167 |
+
domain=domain,
|
168 |
+
task=task,
|
169 |
+
im_per_row=trainer.opts.comet.im_per_row.get(task, 4),
|
170 |
+
rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
|
171 |
+
legends=task_legends[task],
|
172 |
+
)
|
173 |
+
|
174 |
+
if len(save_images) > 1:
|
175 |
+
self.upload_images(
|
176 |
+
image_outputs=all_images,
|
177 |
+
mode=mode,
|
178 |
+
domain=domain,
|
179 |
+
task="all",
|
180 |
+
im_per_row=n_all_ims,
|
181 |
+
rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
|
182 |
+
legends=all_legends,
|
183 |
+
)
|
184 |
+
# ---------------------
|
185 |
+
# ----- Painter -----
|
186 |
+
# ---------------------
|
187 |
+
else:
|
188 |
+
# in the rf domain display_size may be different from fid.n_images
|
189 |
+
limit = trainer.opts.comet.display_size
|
190 |
+
image_outputs = []
|
191 |
+
legends = []
|
192 |
+
for im_set in trainer.display_images[mode][domain][:limit]:
|
193 |
+
x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
|
194 |
+
m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)
|
195 |
+
|
196 |
+
prediction = trainer.G.paint(m, x)
|
197 |
+
|
198 |
+
image_outputs.append(x * (1.0 - m))
|
199 |
+
image_outputs.append(prediction)
|
200 |
+
image_outputs.append(x)
|
201 |
+
image_outputs.append(prediction * m)
|
202 |
+
if not legends:
|
203 |
+
legends.append("Masked Input")
|
204 |
+
legends.append("Painted Input")
|
205 |
+
legends.append("Input")
|
206 |
+
legends.append("Isolated Water")
|
207 |
+
# Write images
|
208 |
+
self.upload_images(
|
209 |
+
image_outputs=image_outputs,
|
210 |
+
mode=mode,
|
211 |
+
domain=domain,
|
212 |
+
task="painter",
|
213 |
+
im_per_row=trainer.opts.comet.im_per_row.get("p", 4),
|
214 |
+
rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
|
215 |
+
legends=legends,
|
216 |
+
)
|
217 |
+
|
218 |
+
return 0
|
219 |
+
|
220 |
+
def log_losses(self, model_to_update="G", mode="train"):
|
221 |
+
"""Logs metrics on comet.ml
|
222 |
+
|
223 |
+
Args:
|
224 |
+
model_to_update (str, optional): One of "G", "D". Defaults to "G".
|
225 |
+
"""
|
226 |
+
trainer = self.trainer
|
227 |
+
loss_names = {"G": "gen", "D": "disc"}
|
228 |
+
|
229 |
+
if trainer.opts.train.log_level < 1:
|
230 |
+
return
|
231 |
+
|
232 |
+
if trainer.exp is None:
|
233 |
+
return
|
234 |
+
|
235 |
+
assert model_to_update in {
|
236 |
+
"G",
|
237 |
+
"D",
|
238 |
+
}, "unknown model to log losses {}".format(model_to_update)
|
239 |
+
|
240 |
+
loss_to_update = self.losses[loss_names[model_to_update]]
|
241 |
+
|
242 |
+
losses = loss_to_update.copy()
|
243 |
+
|
244 |
+
if trainer.opts.train.log_level == 1:
|
245 |
+
# Only log aggregated losses: delete other keys in losses
|
246 |
+
for k in loss_to_update:
|
247 |
+
if k not in {"masker", "total_loss", "painter"}:
|
248 |
+
del losses[k]
|
249 |
+
# convert losses into a single-level dictionnary
|
250 |
+
|
251 |
+
losses = flatten_opts(losses)
|
252 |
+
trainer.exp.log_metrics(
|
253 |
+
losses, prefix=f"{model_to_update}_{mode}", step=self.global_step
|
254 |
+
)
|
255 |
+
|
256 |
+
def log_learning_rates(self):
|
257 |
+
if self.trainer.exp is None:
|
258 |
+
return
|
259 |
+
lrs = {}
|
260 |
+
trainer = self.trainer
|
261 |
+
if trainer.g_scheduler is not None:
|
262 |
+
for name, lr in zip(
|
263 |
+
trainer.lr_names["G"], trainer.g_scheduler.get_last_lr()
|
264 |
+
):
|
265 |
+
lrs[f"lr_G_{name}"] = lr
|
266 |
+
if trainer.d_scheduler is not None:
|
267 |
+
for name, lr in zip(
|
268 |
+
trainer.lr_names["D"], trainer.d_scheduler.get_last_lr()
|
269 |
+
):
|
270 |
+
lrs[f"lr_D_{name}"] = lr
|
271 |
+
|
272 |
+
trainer.exp.log_metrics(lrs, step=self.global_step)
|
273 |
+
|
274 |
+
def log_step_time(self, time):
|
275 |
+
"""Logs step-time on comet.ml
|
276 |
+
|
277 |
+
Args:
|
278 |
+
step_time (float): step-time in seconds
|
279 |
+
"""
|
280 |
+
if self.trainer.exp:
|
281 |
+
self.trainer.exp.log_metric(
|
282 |
+
"step-time", time - self.time.step_start, step=self.global_step
|
283 |
+
)
|
284 |
+
|
285 |
+
def log_epoch_time(self, time):
|
286 |
+
"""Logs step-time on comet.ml
|
287 |
+
|
288 |
+
Args:
|
289 |
+
step_time (float): step-time in seconds
|
290 |
+
"""
|
291 |
+
if self.trainer.exp:
|
292 |
+
self.trainer.exp.log_metric(
|
293 |
+
"epoch-time", time - self.time.epoch_start, step=self.global_step
|
294 |
+
)
|
295 |
+
|
296 |
+
def log_comet_combined_images(self, mode, domain):
|
297 |
+
|
298 |
+
trainer = self.trainer
|
299 |
+
image_outputs = []
|
300 |
+
legends = []
|
301 |
+
im_per_row = 0
|
302 |
+
for i, im_set in enumerate(trainer.display_images[mode][domain]):
|
303 |
+
x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
|
304 |
+
# m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)
|
305 |
+
|
306 |
+
m = trainer.G.mask(x=x)
|
307 |
+
m_bin = (m > 0.5).to(m.dtype)
|
308 |
+
prediction = trainer.G.paint(m, x)
|
309 |
+
prediction_bin = trainer.G.paint(m_bin, x)
|
310 |
+
|
311 |
+
image_outputs.append(x)
|
312 |
+
legends.append("Input")
|
313 |
+
image_outputs.append(x * (1.0 - m))
|
314 |
+
legends.append("Soft Masked Input")
|
315 |
+
image_outputs.append(prediction)
|
316 |
+
legends.append("Painted")
|
317 |
+
image_outputs.append(prediction * m)
|
318 |
+
legends.append("Soft Masked Painted")
|
319 |
+
image_outputs.append(x * (1.0 - m_bin))
|
320 |
+
legends.append("Binary (0.5) Masked Input")
|
321 |
+
image_outputs.append(prediction_bin)
|
322 |
+
legends.append("Binary (0.5) Painted")
|
323 |
+
image_outputs.append(prediction_bin * m_bin)
|
324 |
+
legends.append("Binary (0.5) Masked Painted")
|
325 |
+
|
326 |
+
if i == 0:
|
327 |
+
im_per_row = len(image_outputs)
|
328 |
+
# Upload images
|
329 |
+
self.upload_images(
|
330 |
+
image_outputs=image_outputs,
|
331 |
+
mode=mode,
|
332 |
+
domain=domain,
|
333 |
+
task="combined",
|
334 |
+
im_per_row=im_per_row or 7,
|
335 |
+
rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
|
336 |
+
legends=legends,
|
337 |
+
)
|
338 |
+
|
339 |
+
return 0
|
340 |
+
|
341 |
+
def upload_images(
|
342 |
+
self,
|
343 |
+
image_outputs,
|
344 |
+
mode,
|
345 |
+
domain,
|
346 |
+
task,
|
347 |
+
im_per_row=3,
|
348 |
+
rows_per_log=5,
|
349 |
+
legends=[],
|
350 |
+
):
|
351 |
+
"""
|
352 |
+
Save output image
|
353 |
+
|
354 |
+
Args:
|
355 |
+
image_outputs (list(torch.Tensor)): all the images to log
|
356 |
+
mode (str): train or val
|
357 |
+
domain (str): current domain
|
358 |
+
task (str): current task
|
359 |
+
im_per_row (int, optional): umber of images to be displayed per row.
|
360 |
+
Typically, for a given task: 3 because [input prediction, target].
|
361 |
+
Defaults to 3.
|
362 |
+
rows_per_log (int, optional): Number of rows (=samples) per uploaded image.
|
363 |
+
Defaults to 5.
|
364 |
+
comet_exp (comet_ml.Experiment, optional): experiment to use.
|
365 |
+
Defaults to None.
|
366 |
+
"""
|
367 |
+
trainer = self.trainer
|
368 |
+
if trainer.exp is None:
|
369 |
+
return
|
370 |
+
curr_iter = self.global_step
|
371 |
+
nb_per_log = im_per_row * rows_per_log
|
372 |
+
n_logs = len(image_outputs) // nb_per_log + 1
|
373 |
+
|
374 |
+
header = None
|
375 |
+
if len(legends) == im_per_row and all(isinstance(t, str) for t in legends):
|
376 |
+
header_width = max(im.shape[-1] for im in image_outputs)
|
377 |
+
headers = all_texts_to_tensors(legends, width=header_width)
|
378 |
+
header = torch.cat(headers, dim=-1)
|
379 |
+
|
380 |
+
for logidx in range(n_logs):
|
381 |
+
print(" " * 100, end="\r", flush=True)
|
382 |
+
print(
|
383 |
+
"Uploading images for {} {} {} {}/{}".format(
|
384 |
+
mode, domain, task, logidx + 1, n_logs
|
385 |
+
),
|
386 |
+
end="...",
|
387 |
+
flush=True,
|
388 |
+
)
|
389 |
+
ims = image_outputs[logidx * nb_per_log : (logidx + 1) * nb_per_log]
|
390 |
+
if not ims:
|
391 |
+
continue
|
392 |
+
|
393 |
+
ims = self.upsample(ims)
|
394 |
+
ims = torch.stack([im.squeeze() for im in ims]).squeeze()
|
395 |
+
image_grid = vutils.make_grid(
|
396 |
+
ims, nrow=im_per_row, normalize=True, scale_each=True, padding=0
|
397 |
+
)
|
398 |
+
|
399 |
+
if header is not None:
|
400 |
+
image_grid = torch.cat(
|
401 |
+
[header.to(image_grid.device), image_grid], dim=1
|
402 |
+
)
|
403 |
+
|
404 |
+
image_grid = image_grid.permute(1, 2, 0).cpu().numpy()
|
405 |
+
trainer.exp.log_image(
|
406 |
+
Image.fromarray((image_grid * 255).astype(np.uint8)),
|
407 |
+
name=f"{mode}_{domain}_{task}_{str(curr_iter)}_#{logidx}",
|
408 |
+
step=curr_iter,
|
409 |
+
)
|
410 |
+
|
411 |
+
def upsample(self, ims):
|
412 |
+
h = max(im.shape[-2] for im in ims)
|
413 |
+
w = max(im.shape[-1] for im in ims)
|
414 |
+
new_ims = []
|
415 |
+
for im in ims:
|
416 |
+
im = interpolate(im, (h, w), mode="bilinear")
|
417 |
+
new_ims.append(im)
|
418 |
+
return new_ims
|
419 |
+
|
420 |
+
def padd(self, ims):
|
421 |
+
h = max(im.shape[-2] for im in ims)
|
422 |
+
w = max(im.shape[-1] for im in ims)
|
423 |
+
new_ims = []
|
424 |
+
for im in ims:
|
425 |
+
ih = im.shape[-2]
|
426 |
+
iw = im.shape[-1]
|
427 |
+
if ih != h or iw != w:
|
428 |
+
padded = torch.zeros(im.shape[-3], h, w)
|
429 |
+
padded[
|
430 |
+
:, (h - ih) // 2 : (h + ih) // 2, (w - iw) // 2 : (w + iw) // 2
|
431 |
+
] = im
|
432 |
+
new_ims.append(padded)
|
433 |
+
else:
|
434 |
+
new_ims.append(im)
|
435 |
+
|
436 |
+
return new_ims
|
437 |
+
|
438 |
+
def log_architecture(self):
|
439 |
+
write_architecture(self.trainer)
|
440 |
+
|
441 |
+
if self.trainer.exp is None:
|
442 |
+
return
|
443 |
+
|
444 |
+
for f in Path(self.trainer.opts.output_path).glob("archi*.txt"):
|
445 |
+
self.trainer.exp.log_asset(str(f), overwrite=True)
|
climategan/losses.py
ADDED
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Define all losses. When possible, as inheriting from nn.Module
|
2 |
+
To send predictions to target.device
|
3 |
+
"""
|
4 |
+
from random import random as rand
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torchvision import models
|
11 |
+
|
12 |
+
|
13 |
+
class GANLoss(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
use_lsgan=True,
|
17 |
+
target_real_label=1.0,
|
18 |
+
target_fake_label=0.0,
|
19 |
+
soft_shift=0.0,
|
20 |
+
flip_prob=0.0,
|
21 |
+
verbose=0,
|
22 |
+
):
|
23 |
+
"""Defines the GAN loss which uses either LSGAN or the regular GAN.
|
24 |
+
When LSGAN is used, it is basically same as MSELoss,
|
25 |
+
but it abstracts away the need to create the target label tensor
|
26 |
+
that has the same size as the input +
|
27 |
+
|
28 |
+
* label smoothing: target_real_label=0.75
|
29 |
+
* label flipping: flip_prob > 0.
|
30 |
+
|
31 |
+
source: https://github.com/sangwoomo/instagan/blob
|
32 |
+
/b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py
|
33 |
+
|
34 |
+
Args:
|
35 |
+
use_lsgan (bool, optional): Use MSE or BCE. Defaults to True.
|
36 |
+
target_real_label (float, optional): Value for the real target.
|
37 |
+
Defaults to 1.0.
|
38 |
+
target_fake_label (float, optional): Value for the fake target.
|
39 |
+
Defaults to 0.0.
|
40 |
+
flip_prob (float, optional): Probability of flipping the label
|
41 |
+
(use for real target in Discriminator only). Defaults to 0.0.
|
42 |
+
"""
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.soft_shift = soft_shift
|
46 |
+
self.verbose = verbose
|
47 |
+
|
48 |
+
self.register_buffer("real_label", torch.tensor(target_real_label))
|
49 |
+
self.register_buffer("fake_label", torch.tensor(target_fake_label))
|
50 |
+
if use_lsgan:
|
51 |
+
self.loss = nn.MSELoss()
|
52 |
+
else:
|
53 |
+
self.loss = nn.BCEWithLogitsLoss()
|
54 |
+
self.flip_prob = flip_prob
|
55 |
+
|
56 |
+
def get_target_tensor(self, input, target_is_real):
|
57 |
+
soft_change = torch.FloatTensor(1).uniform_(0, self.soft_shift)
|
58 |
+
if self.verbose > 0:
|
59 |
+
print("GANLoss sampled soft_change:", soft_change.item())
|
60 |
+
if target_is_real:
|
61 |
+
target_tensor = self.real_label - soft_change
|
62 |
+
else:
|
63 |
+
target_tensor = self.fake_label + soft_change
|
64 |
+
return target_tensor.expand_as(input)
|
65 |
+
|
66 |
+
def __call__(self, input, target_is_real, *args, **kwargs):
|
67 |
+
r = rand()
|
68 |
+
if isinstance(input, list):
|
69 |
+
loss = 0
|
70 |
+
for pred_i in input:
|
71 |
+
if isinstance(pred_i, list):
|
72 |
+
pred_i = pred_i[-1]
|
73 |
+
if r < self.flip_prob:
|
74 |
+
target_is_real = not target_is_real
|
75 |
+
target_tensor = self.get_target_tensor(pred_i, target_is_real)
|
76 |
+
loss_tensor = self.loss(pred_i, target_tensor.to(pred_i.device))
|
77 |
+
loss += loss_tensor
|
78 |
+
return loss / len(input)
|
79 |
+
else:
|
80 |
+
if r < self.flip_prob:
|
81 |
+
target_is_real = not target_is_real
|
82 |
+
target_tensor = self.get_target_tensor(input, target_is_real)
|
83 |
+
return self.loss(input, target_tensor.to(input.device))
|
84 |
+
|
85 |
+
|
86 |
+
class FeatMatchLoss(nn.Module):
|
87 |
+
def __init__(self):
|
88 |
+
super().__init__()
|
89 |
+
self.criterionFeat = nn.L1Loss()
|
90 |
+
|
91 |
+
def __call__(self, pred_real, pred_fake):
|
92 |
+
# pred_{real, fake} are lists of features
|
93 |
+
num_D = len(pred_fake)
|
94 |
+
GAN_Feat_loss = 0.0
|
95 |
+
for i in range(num_D): # for each discriminator
|
96 |
+
# last output is the final prediction, so we exclude it
|
97 |
+
num_intermediate_outputs = len(pred_fake[i]) - 1
|
98 |
+
for j in range(num_intermediate_outputs): # for each layer output
|
99 |
+
unweighted_loss = self.criterionFeat(
|
100 |
+
pred_fake[i][j], pred_real[i][j].detach()
|
101 |
+
)
|
102 |
+
GAN_Feat_loss += unweighted_loss / num_D
|
103 |
+
return GAN_Feat_loss
|
104 |
+
|
105 |
+
|
106 |
+
class CrossEntropy(nn.Module):
|
107 |
+
def __init__(self):
|
108 |
+
super().__init__()
|
109 |
+
self.loss = nn.CrossEntropyLoss()
|
110 |
+
|
111 |
+
def __call__(self, logits, target):
|
112 |
+
return self.loss(logits, target.to(logits.device).long())
|
113 |
+
|
114 |
+
|
115 |
+
class TravelLoss(nn.Module):
|
116 |
+
def __init__(self, eps=1e-12):
|
117 |
+
super().__init__()
|
118 |
+
self.eps = eps
|
119 |
+
|
120 |
+
def cosine_loss(self, real, fake):
|
121 |
+
norm_real = torch.norm(real, p=2, dim=1)[:, None]
|
122 |
+
norm_fake = torch.norm(fake, p=2, dim=1)[:, None]
|
123 |
+
mat_real = real / norm_real
|
124 |
+
mat_fake = fake / norm_fake
|
125 |
+
mat_real = torch.max(mat_real, self.eps * torch.ones_like(mat_real))
|
126 |
+
mat_fake = torch.max(mat_fake, self.eps * torch.ones_like(mat_fake))
|
127 |
+
# compute only the diagonal of the matrix multiplication
|
128 |
+
return torch.einsum("ij, ji -> i", mat_fake, mat_real).sum()
|
129 |
+
|
130 |
+
def __call__(self, S_real, S_fake):
|
131 |
+
self.v_real = []
|
132 |
+
self.v_fake = []
|
133 |
+
for i in range(len(S_real)):
|
134 |
+
for j in range(i):
|
135 |
+
self.v_real.append((S_real[i] - S_real[j])[None, :])
|
136 |
+
self.v_fake.append((S_fake[i] - S_fake[j])[None, :])
|
137 |
+
self.v_real_t = torch.cat(self.v_real, dim=0)
|
138 |
+
self.v_fake_t = torch.cat(self.v_fake, dim=0)
|
139 |
+
return self.cosine_loss(self.v_real_t, self.v_fake_t)
|
140 |
+
|
141 |
+
|
142 |
+
class TVLoss(nn.Module):
|
143 |
+
"""Total Variational Regularization: Penalizes differences in
|
144 |
+
neighboring pixel values
|
145 |
+
|
146 |
+
source:
|
147 |
+
https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self, tvloss_weight=1):
|
151 |
+
"""
|
152 |
+
Args:
|
153 |
+
TVLoss_weight (int, optional): [lambda i.e. weight for loss]. Defaults to 1.
|
154 |
+
"""
|
155 |
+
super(TVLoss, self).__init__()
|
156 |
+
self.tvloss_weight = tvloss_weight
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
batch_size = x.size()[0]
|
160 |
+
h_x = x.size()[2]
|
161 |
+
w_x = x.size()[3]
|
162 |
+
count_h = self._tensor_size(x[:, :, 1:, :])
|
163 |
+
count_w = self._tensor_size(x[:, :, :, 1:])
|
164 |
+
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, : h_x - 1, :]), 2).sum()
|
165 |
+
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, : w_x - 1]), 2).sum()
|
166 |
+
return self.tvloss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
|
167 |
+
|
168 |
+
def _tensor_size(self, t):
|
169 |
+
return t.size()[1] * t.size()[2] * t.size()[3]
|
170 |
+
|
171 |
+
|
172 |
+
class MinentLoss(nn.Module):
|
173 |
+
"""
|
174 |
+
Loss for the minimization of the entropy map
|
175 |
+
Source for version 1: https://github.com/valeoai/ADVENT
|
176 |
+
|
177 |
+
Version 2 adds the variance of the entropy map in the computation of the loss
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(self, version=1, lambda_var=0.1):
|
181 |
+
super().__init__()
|
182 |
+
self.version = version
|
183 |
+
self.lambda_var = lambda_var
|
184 |
+
|
185 |
+
def __call__(self, pred):
|
186 |
+
assert pred.dim() == 4
|
187 |
+
n, c, h, w = pred.size()
|
188 |
+
entropy_map = -torch.mul(pred, torch.log2(pred + 1e-30)) / np.log2(c)
|
189 |
+
if self.version == 1:
|
190 |
+
return torch.sum(entropy_map) / (n * h * w)
|
191 |
+
else:
|
192 |
+
entropy_map_demean = entropy_map - torch.sum(entropy_map) / (n * h * w)
|
193 |
+
entropy_map_squ = torch.mul(entropy_map_demean, entropy_map_demean)
|
194 |
+
return torch.sum(entropy_map + self.lambda_var * entropy_map_squ) / (
|
195 |
+
n * h * w
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
class MSELoss(nn.Module):
|
200 |
+
"""
|
201 |
+
Creates a criterion that measures the mean squared error
|
202 |
+
(squared L2 norm) between each element in the input x and target y .
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self):
|
206 |
+
super().__init__()
|
207 |
+
self.loss = nn.MSELoss()
|
208 |
+
|
209 |
+
def __call__(self, prediction, target):
|
210 |
+
return self.loss(prediction, target.to(prediction.device))
|
211 |
+
|
212 |
+
|
213 |
+
class L1Loss(MSELoss):
|
214 |
+
"""
|
215 |
+
Creates a criterion that measures the mean absolute error
|
216 |
+
(MAE) between each element in the input x and target y
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self):
|
220 |
+
super().__init__()
|
221 |
+
self.loss = nn.L1Loss()
|
222 |
+
|
223 |
+
|
224 |
+
class SIMSELoss(nn.Module):
|
225 |
+
"""Scale invariant MSE Loss"""
|
226 |
+
|
227 |
+
def __init__(self):
|
228 |
+
super(SIMSELoss, self).__init__()
|
229 |
+
|
230 |
+
def __call__(self, prediction, target):
|
231 |
+
d = prediction - target
|
232 |
+
diff = torch.mean(d * d)
|
233 |
+
relDiff = torch.mean(d) * torch.mean(d)
|
234 |
+
return diff - relDiff
|
235 |
+
|
236 |
+
|
237 |
+
class SIGMLoss(nn.Module):
|
238 |
+
"""loss from MiDaS paper
|
239 |
+
MiDaS did not specify how the gradients were computed but we use Sobel
|
240 |
+
filters which approximate the derivative of an image.
|
241 |
+
"""
|
242 |
+
|
243 |
+
def __init__(self, gmweight=0.5, scale=4, device="cuda"):
|
244 |
+
super(SIGMLoss, self).__init__()
|
245 |
+
self.gmweight = gmweight
|
246 |
+
self.sobelx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(device)
|
247 |
+
self.sobely = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(device)
|
248 |
+
self.scale = scale
|
249 |
+
|
250 |
+
def __call__(self, prediction, target):
|
251 |
+
# get disparities
|
252 |
+
# align both the prediction and the ground truth to have zero
|
253 |
+
# translation and unit scale
|
254 |
+
t_pred = torch.median(prediction)
|
255 |
+
t_targ = torch.median(target)
|
256 |
+
s_pred = torch.mean(torch.abs(prediction - t_pred))
|
257 |
+
s_targ = torch.mean(torch.abs(target - t_targ))
|
258 |
+
pred = (prediction - t_pred) / s_pred
|
259 |
+
targ = (target - t_targ) / s_targ
|
260 |
+
|
261 |
+
R = pred - targ
|
262 |
+
|
263 |
+
# get gradient map with sobel filters
|
264 |
+
batch_size = prediction.size()[0]
|
265 |
+
num_pix = prediction.size()[-1] * prediction.size()[-2]
|
266 |
+
sobelx = (self.sobelx).expand((batch_size, 1, -1, -1))
|
267 |
+
sobely = (self.sobely).expand((batch_size, 1, -1, -1))
|
268 |
+
gmLoss = 0 # gradient matching term
|
269 |
+
for k in range(self.scale):
|
270 |
+
R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
|
271 |
+
Rx = F.conv2d(R_, sobelx, stride=1)
|
272 |
+
Ry = F.conv2d(R_, sobely, stride=1)
|
273 |
+
gmLoss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
|
274 |
+
gmLoss = self.gmweight / num_pix * gmLoss
|
275 |
+
# scale invariant MSE
|
276 |
+
simseLoss = 0.5 / num_pix * torch.sum(torch.abs(R))
|
277 |
+
loss = simseLoss + gmLoss
|
278 |
+
return loss
|
279 |
+
|
280 |
+
|
281 |
+
class ContextLoss(nn.Module):
|
282 |
+
"""
|
283 |
+
Masked L1 loss on non-water
|
284 |
+
"""
|
285 |
+
|
286 |
+
def __call__(self, input, target, mask):
|
287 |
+
return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))
|
288 |
+
|
289 |
+
|
290 |
+
class ReconstructionLoss(nn.Module):
|
291 |
+
"""
|
292 |
+
Masked L1 loss on water
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __call__(self, input, target, mask):
|
296 |
+
return torch.mean(torch.abs(torch.mul((input - target), mask)))
|
297 |
+
|
298 |
+
|
299 |
+
##################################################################################
|
300 |
+
# VGG network definition
|
301 |
+
##################################################################################
|
302 |
+
|
303 |
+
# Source: https://github.com/NVIDIA/pix2pixHD
|
304 |
+
class Vgg19(nn.Module):
|
305 |
+
def __init__(self, requires_grad=False):
|
306 |
+
super(Vgg19, self).__init__()
|
307 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
308 |
+
self.slice1 = nn.Sequential()
|
309 |
+
self.slice2 = nn.Sequential()
|
310 |
+
self.slice3 = nn.Sequential()
|
311 |
+
self.slice4 = nn.Sequential()
|
312 |
+
self.slice5 = nn.Sequential()
|
313 |
+
for x in range(2):
|
314 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
315 |
+
for x in range(2, 7):
|
316 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
317 |
+
for x in range(7, 12):
|
318 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
319 |
+
for x in range(12, 21):
|
320 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
321 |
+
for x in range(21, 30):
|
322 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
323 |
+
if not requires_grad:
|
324 |
+
for param in self.parameters():
|
325 |
+
param.requires_grad = False
|
326 |
+
|
327 |
+
def forward(self, X):
|
328 |
+
h_relu1 = self.slice1(X)
|
329 |
+
h_relu2 = self.slice2(h_relu1)
|
330 |
+
h_relu3 = self.slice3(h_relu2)
|
331 |
+
h_relu4 = self.slice4(h_relu3)
|
332 |
+
h_relu5 = self.slice5(h_relu4)
|
333 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
334 |
+
return out
|
335 |
+
|
336 |
+
|
337 |
+
# Source: https://github.com/NVIDIA/pix2pixHD
|
338 |
+
class VGGLoss(nn.Module):
|
339 |
+
def __init__(self, device):
|
340 |
+
super().__init__()
|
341 |
+
self.vgg = Vgg19().to(device).eval()
|
342 |
+
self.criterion = nn.L1Loss()
|
343 |
+
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
|
344 |
+
|
345 |
+
def forward(self, x, y):
|
346 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
347 |
+
loss = 0
|
348 |
+
for i in range(len(x_vgg)):
|
349 |
+
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
350 |
+
return loss
|
351 |
+
|
352 |
+
|
353 |
+
def get_losses(opts, verbose, device=None):
|
354 |
+
"""Sets the loss functions to be used by G, D and C, as specified
|
355 |
+
in the opts and returns a dictionnary of losses:
|
356 |
+
|
357 |
+
losses = {
|
358 |
+
"G": {
|
359 |
+
"gan": {"a": ..., "t": ...},
|
360 |
+
"cycle": {"a": ..., "t": ...}
|
361 |
+
"auto": {"a": ..., "t": ...}
|
362 |
+
"tasks": {"h": ..., "d": ..., "s": ..., etc.}
|
363 |
+
},
|
364 |
+
"D": GANLoss,
|
365 |
+
"C": ...
|
366 |
+
}
|
367 |
+
"""
|
368 |
+
|
369 |
+
losses = {
|
370 |
+
"G": {"a": {}, "p": {}, "tasks": {}},
|
371 |
+
"D": {"default": {}, "advent": {}},
|
372 |
+
"C": {},
|
373 |
+
}
|
374 |
+
|
375 |
+
# ------------------------------
|
376 |
+
# ----- Generator Losses -----
|
377 |
+
# ------------------------------
|
378 |
+
|
379 |
+
# painter losses
|
380 |
+
if "p" in opts.tasks:
|
381 |
+
losses["G"]["p"]["gan"] = (
|
382 |
+
HingeLoss()
|
383 |
+
if opts.gen.p.loss == "hinge"
|
384 |
+
else GANLoss(
|
385 |
+
use_lsgan=False,
|
386 |
+
soft_shift=opts.dis.soft_shift,
|
387 |
+
flip_prob=opts.dis.flip_prob,
|
388 |
+
)
|
389 |
+
)
|
390 |
+
losses["G"]["p"]["dm"] = MSELoss()
|
391 |
+
losses["G"]["p"]["vgg"] = VGGLoss(device)
|
392 |
+
losses["G"]["p"]["tv"] = TVLoss()
|
393 |
+
losses["G"]["p"]["context"] = ContextLoss()
|
394 |
+
losses["G"]["p"]["reconstruction"] = ReconstructionLoss()
|
395 |
+
losses["G"]["p"]["featmatch"] = FeatMatchLoss()
|
396 |
+
|
397 |
+
# depth losses
|
398 |
+
if "d" in opts.tasks:
|
399 |
+
if not opts.gen.d.classify.enable:
|
400 |
+
if opts.gen.d.loss == "dada":
|
401 |
+
depth_func = DADADepthLoss()
|
402 |
+
else:
|
403 |
+
depth_func = SIGMLoss(opts.train.lambdas.G.d.gml)
|
404 |
+
else:
|
405 |
+
depth_func = CrossEntropy()
|
406 |
+
|
407 |
+
losses["G"]["tasks"]["d"] = depth_func
|
408 |
+
|
409 |
+
# segmentation losses
|
410 |
+
if "s" in opts.tasks:
|
411 |
+
losses["G"]["tasks"]["s"] = {}
|
412 |
+
losses["G"]["tasks"]["s"]["crossent"] = CrossEntropy()
|
413 |
+
losses["G"]["tasks"]["s"]["minent"] = MinentLoss()
|
414 |
+
losses["G"]["tasks"]["s"]["advent"] = ADVENTAdversarialLoss(
|
415 |
+
opts, gan_type=opts.dis.s.gan_type
|
416 |
+
)
|
417 |
+
|
418 |
+
# masker losses
|
419 |
+
if "m" in opts.tasks:
|
420 |
+
losses["G"]["tasks"]["m"] = {}
|
421 |
+
losses["G"]["tasks"]["m"]["bce"] = nn.BCEWithLogitsLoss()
|
422 |
+
if opts.gen.m.use_minent_var:
|
423 |
+
losses["G"]["tasks"]["m"]["minent"] = MinentLoss(
|
424 |
+
version=2, lambda_var=opts.train.lambdas.advent.ent_var
|
425 |
+
)
|
426 |
+
else:
|
427 |
+
losses["G"]["tasks"]["m"]["minent"] = MinentLoss()
|
428 |
+
losses["G"]["tasks"]["m"]["tv"] = TVLoss()
|
429 |
+
losses["G"]["tasks"]["m"]["advent"] = ADVENTAdversarialLoss(
|
430 |
+
opts, gan_type=opts.dis.m.gan_type
|
431 |
+
)
|
432 |
+
losses["G"]["tasks"]["m"]["gi"] = GroundIntersectionLoss()
|
433 |
+
|
434 |
+
# ----------------------------------
|
435 |
+
# ----- Discriminator Losses -----
|
436 |
+
# ----------------------------------
|
437 |
+
if "p" in opts.tasks:
|
438 |
+
losses["D"]["p"] = losses["G"]["p"]["gan"]
|
439 |
+
if "m" in opts.tasks or "s" in opts.tasks:
|
440 |
+
losses["D"]["advent"] = ADVENTAdversarialLoss(opts)
|
441 |
+
return losses
|
442 |
+
|
443 |
+
|
444 |
+
class GroundIntersectionLoss(nn.Module):
|
445 |
+
"""
|
446 |
+
Penalize areas in ground seg but not in flood mask
|
447 |
+
"""
|
448 |
+
|
449 |
+
def __call__(self, pred, pseudo_ground):
|
450 |
+
return torch.mean(1.0 * ((pseudo_ground - pred) > 0.5))
|
451 |
+
|
452 |
+
|
453 |
+
def prob_2_entropy(prob):
|
454 |
+
"""
|
455 |
+
convert probabilistic prediction maps to weighted self-information maps
|
456 |
+
"""
|
457 |
+
n, c, h, w = prob.size()
|
458 |
+
return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)
|
459 |
+
|
460 |
+
|
461 |
+
class CustomBCELoss(nn.Module):
|
462 |
+
"""
|
463 |
+
The first argument is a tensor and the second argument is an int.
|
464 |
+
There is no need to take sigmoid before calling this function.
|
465 |
+
"""
|
466 |
+
|
467 |
+
def __init__(self):
|
468 |
+
super().__init__()
|
469 |
+
self.loss = nn.BCEWithLogitsLoss()
|
470 |
+
|
471 |
+
def __call__(self, prediction, target):
|
472 |
+
return self.loss(
|
473 |
+
prediction,
|
474 |
+
torch.FloatTensor(prediction.size())
|
475 |
+
.fill_(target)
|
476 |
+
.to(prediction.get_device()),
|
477 |
+
)
|
478 |
+
|
479 |
+
|
480 |
+
class ADVENTAdversarialLoss(nn.Module):
|
481 |
+
"""
|
482 |
+
The class is for calculating the advent loss.
|
483 |
+
It is used to indirectly shrink the domain gap between sim and real
|
484 |
+
|
485 |
+
_call_ function:
|
486 |
+
prediction: torch.tensor with shape of [bs,c,h,w]
|
487 |
+
target: int; domain label: 0 (sim) or 1 (real)
|
488 |
+
discriminator: the discriminator model tells if a tensor is from sim or real
|
489 |
+
|
490 |
+
output: the loss value of GANLoss
|
491 |
+
"""
|
492 |
+
|
493 |
+
def __init__(self, opts, gan_type="GAN"):
|
494 |
+
super().__init__()
|
495 |
+
self.opts = opts
|
496 |
+
if gan_type == "GAN":
|
497 |
+
self.loss = CustomBCELoss()
|
498 |
+
elif gan_type == "WGAN" or "WGAN_gp" or "WGAN_norm":
|
499 |
+
self.loss = lambda x, y: -torch.mean(y * x + (1 - y) * (1 - x))
|
500 |
+
else:
|
501 |
+
raise NotImplementedError
|
502 |
+
|
503 |
+
def __call__(self, prediction, target, discriminator, depth_preds=None):
|
504 |
+
"""
|
505 |
+
Compute the GAN loss from the Advent Discriminator given
|
506 |
+
normalized (softmaxed) predictions (=pixel-wise class probabilities),
|
507 |
+
and int labels (target).
|
508 |
+
|
509 |
+
Args:
|
510 |
+
prediction (torch.Tensor): pixel-wise probability distribution over classes
|
511 |
+
target (torch.Tensor): pixel wise int target labels
|
512 |
+
discriminator (torch.nn.Module): Discriminator to get the loss
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
torch.Tensor: float 0-D loss
|
516 |
+
"""
|
517 |
+
d_out = prob_2_entropy(prediction)
|
518 |
+
if depth_preds is not None:
|
519 |
+
d_out = d_out * depth_preds
|
520 |
+
d_out = discriminator(d_out)
|
521 |
+
if self.opts.dis.m.architecture == "OmniDiscriminator":
|
522 |
+
d_out = multiDiscriminatorAdapter(d_out, self.opts)
|
523 |
+
loss_ = self.loss(d_out, target)
|
524 |
+
return loss_
|
525 |
+
|
526 |
+
|
527 |
+
def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.tensor:
|
528 |
+
"""
|
529 |
+
Because the OmniDiscriminator does not directly return a tensor
|
530 |
+
(but a list of tensor).
|
531 |
+
Since there is no multilevel masker, the 0th tensor in the list is all we want.
|
532 |
+
This Adapter returns the first element(tensor) of the list that OmniDiscriminator
|
533 |
+
returns.
|
534 |
+
"""
|
535 |
+
if (
|
536 |
+
isinstance(d_out, list) and len(d_out) == 1
|
537 |
+
): # adapt the multi-scale OmniDiscriminator
|
538 |
+
if not opts.dis.p.get_intermediate_features:
|
539 |
+
d_out = d_out[0][0]
|
540 |
+
else:
|
541 |
+
d_out = d_out[0]
|
542 |
+
else:
|
543 |
+
raise Exception(
|
544 |
+
"Check the setting of OmniDiscriminator! "
|
545 |
+
+ "For now, we don't support multi-scale OmniDiscriminator."
|
546 |
+
)
|
547 |
+
return d_out
|
548 |
+
|
549 |
+
|
550 |
+
class HingeLoss(nn.Module):
|
551 |
+
"""
|
552 |
+
Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py
|
553 |
+
for the painter
|
554 |
+
"""
|
555 |
+
|
556 |
+
def __init__(self, tensor=torch.FloatTensor):
|
557 |
+
super().__init__()
|
558 |
+
self.zero_tensor = None
|
559 |
+
self.Tensor = tensor
|
560 |
+
|
561 |
+
def get_zero_tensor(self, input):
|
562 |
+
if self.zero_tensor is None:
|
563 |
+
self.zero_tensor = self.Tensor(1).fill_(0)
|
564 |
+
self.zero_tensor.requires_grad_(False)
|
565 |
+
self.zero_tensor = self.zero_tensor.to(input.device)
|
566 |
+
return self.zero_tensor.expand_as(input)
|
567 |
+
|
568 |
+
def loss(self, input, target_is_real, for_discriminator=True):
|
569 |
+
if for_discriminator:
|
570 |
+
if target_is_real:
|
571 |
+
minval = torch.min(input - 1, self.get_zero_tensor(input))
|
572 |
+
loss = -torch.mean(minval)
|
573 |
+
else:
|
574 |
+
minval = torch.min(-input - 1, self.get_zero_tensor(input))
|
575 |
+
loss = -torch.mean(minval)
|
576 |
+
else:
|
577 |
+
assert target_is_real, "The generator's hinge loss must be aiming for real"
|
578 |
+
loss = -torch.mean(input)
|
579 |
+
return loss
|
580 |
+
|
581 |
+
def __call__(self, input, target_is_real, for_discriminator=True):
|
582 |
+
# computing loss is a bit complicated because |input| may not be
|
583 |
+
# a tensor, but list of tensors in case of multiscale discriminator
|
584 |
+
if isinstance(input, list):
|
585 |
+
loss = 0
|
586 |
+
for pred_i in input:
|
587 |
+
if isinstance(pred_i, list):
|
588 |
+
pred_i = pred_i[-1]
|
589 |
+
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
|
590 |
+
loss += loss_tensor
|
591 |
+
return loss / len(input)
|
592 |
+
else:
|
593 |
+
return self.loss(input, target_is_real, for_discriminator)
|
594 |
+
|
595 |
+
|
596 |
+
class DADADepthLoss:
|
597 |
+
"""Defines the reverse Huber loss from DADA paper for depth prediction
|
598 |
+
- Samples with larger residuals are penalized more by l2 term
|
599 |
+
- Samples with smaller residuals are penalized more by l1 term
|
600 |
+
From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py
|
601 |
+
"""
|
602 |
+
|
603 |
+
def loss_calc_depth(self, pred, label):
|
604 |
+
n, c, h, w = pred.size()
|
605 |
+
assert c == 1
|
606 |
+
|
607 |
+
pred = pred.squeeze()
|
608 |
+
label = label.squeeze()
|
609 |
+
|
610 |
+
adiff = torch.abs(pred - label)
|
611 |
+
batch_max = 0.2 * torch.max(adiff).item()
|
612 |
+
t1_mask = adiff.le(batch_max).float()
|
613 |
+
t2_mask = adiff.gt(batch_max).float()
|
614 |
+
t1 = adiff * t1_mask
|
615 |
+
t2 = (adiff * adiff + batch_max * batch_max) / (2 * batch_max)
|
616 |
+
t2 = t2 * t2_mask
|
617 |
+
return (torch.sum(t1) + torch.sum(t2)) / torch.numel(pred.data)
|
618 |
+
|
619 |
+
def __call__(self, pred, label):
|
620 |
+
return self.loss_calc_depth(pred, label)
|
climategan/masker.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from climategan.blocks import (
|
6 |
+
BaseDecoder,
|
7 |
+
Conv2dBlock,
|
8 |
+
InterpolateNearest2d,
|
9 |
+
SPADEResnetBlock,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def create_mask_decoder(opts, no_init=False, verbose=0):
|
14 |
+
if opts.gen.m.use_spade:
|
15 |
+
if verbose > 0:
|
16 |
+
print(" - Add Spade Mask Decoder")
|
17 |
+
assert "d" in opts.tasks or "s" in opts.tasks
|
18 |
+
return MaskSpadeDecoder(opts)
|
19 |
+
else:
|
20 |
+
if verbose > 0:
|
21 |
+
print(" - Add Base Mask Decoder")
|
22 |
+
return MaskBaseDecoder(opts)
|
23 |
+
|
24 |
+
|
25 |
+
class MaskBaseDecoder(BaseDecoder):
|
26 |
+
def __init__(self, opts):
|
27 |
+
low_level_feats_dim = -1
|
28 |
+
use_v3 = opts.gen.encoder.architecture == "deeplabv3"
|
29 |
+
use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
|
30 |
+
use_low = opts.gen.m.use_low_level_feats
|
31 |
+
use_dada = ("d" in opts.tasks) and opts.gen.m.use_dada
|
32 |
+
|
33 |
+
if use_v3 and use_mobile_net:
|
34 |
+
input_dim = 320
|
35 |
+
if use_low:
|
36 |
+
low_level_feats_dim = 24
|
37 |
+
elif use_v3:
|
38 |
+
input_dim = 2048
|
39 |
+
if use_low:
|
40 |
+
low_level_feats_dim = 256
|
41 |
+
else:
|
42 |
+
input_dim = 2048
|
43 |
+
|
44 |
+
super().__init__(
|
45 |
+
n_upsample=opts.gen.m.n_upsample,
|
46 |
+
n_res=opts.gen.m.n_res,
|
47 |
+
input_dim=input_dim,
|
48 |
+
proj_dim=opts.gen.m.proj_dim,
|
49 |
+
output_dim=opts.gen.m.output_dim,
|
50 |
+
norm=opts.gen.m.norm,
|
51 |
+
activ=opts.gen.m.activ,
|
52 |
+
pad_type=opts.gen.m.pad_type,
|
53 |
+
output_activ="none",
|
54 |
+
low_level_feats_dim=low_level_feats_dim,
|
55 |
+
use_dada=use_dada,
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
class MaskSpadeDecoder(nn.Module):
|
60 |
+
def __init__(self, opts):
|
61 |
+
"""Create a SPADE-based decoder, which forwards z and the conditioning
|
62 |
+
tensors seg (in the original paper, conditioning is on a semantic map only).
|
63 |
+
All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
|
64 |
+
the channel dimension, and an upsampling is applied after each. Therefore
|
65 |
+
2 upsamplings at this point. Then, for each remaining upsamplings
|
66 |
+
(w.r.t. spade_n_up), the SRB shrinks channels by 2. Before final conv to get 3
|
67 |
+
channels, the number of channels is therefore:
|
68 |
+
final_nc = channels(z) * 2 ** (spade_n_up - 2)
|
69 |
+
Args:
|
70 |
+
latent_dim (tuple): z's shape (only the number of channels matters)
|
71 |
+
cond_nc (int): conditioning tensor's expected number of channels
|
72 |
+
spade_n_up (int): Number of total upsamplings from z
|
73 |
+
spade_use_spectral_norm (bool): use spectral normalization?
|
74 |
+
spade_param_free_norm (str): norm to use before SPADE de-normalization
|
75 |
+
spade_kernel_size (int): SPADE conv layers' kernel size
|
76 |
+
Returns:
|
77 |
+
[type]: [description]
|
78 |
+
"""
|
79 |
+
super().__init__()
|
80 |
+
self.opts = opts
|
81 |
+
latent_dim = opts.gen.m.spade.latent_dim
|
82 |
+
cond_nc = opts.gen.m.spade.cond_nc
|
83 |
+
spade_use_spectral_norm = opts.gen.m.spade.spade_use_spectral_norm
|
84 |
+
spade_param_free_norm = opts.gen.m.spade.spade_param_free_norm
|
85 |
+
if self.opts.gen.m.spade.activations.all_lrelu:
|
86 |
+
spade_activation = "lrelu"
|
87 |
+
else:
|
88 |
+
spade_activation = None
|
89 |
+
spade_kernel_size = 3
|
90 |
+
self.num_layers = opts.gen.m.spade.num_layers
|
91 |
+
self.z_nc = latent_dim
|
92 |
+
|
93 |
+
if (
|
94 |
+
opts.gen.encoder.architecture == "deeplabv3"
|
95 |
+
and opts.gen.deeplabv3.backbone == "mobilenet"
|
96 |
+
):
|
97 |
+
self.input_dim = [320, 24]
|
98 |
+
self.low_level_conv = Conv2dBlock(
|
99 |
+
self.input_dim[1],
|
100 |
+
self.input_dim[0],
|
101 |
+
3,
|
102 |
+
padding=1,
|
103 |
+
activation="lrelu",
|
104 |
+
pad_type="reflect",
|
105 |
+
norm="spectral_batch",
|
106 |
+
)
|
107 |
+
self.merge_feats_conv = Conv2dBlock(
|
108 |
+
self.input_dim[0] * 2,
|
109 |
+
self.z_nc,
|
110 |
+
3,
|
111 |
+
padding=1,
|
112 |
+
activation="lrelu",
|
113 |
+
pad_type="reflect",
|
114 |
+
norm="spectral_batch",
|
115 |
+
)
|
116 |
+
elif (
|
117 |
+
opts.gen.encoder.architecture == "deeplabv3"
|
118 |
+
and opts.gen.deeplabv3.backbone == "resnet"
|
119 |
+
):
|
120 |
+
self.input_dim = [2048, 256]
|
121 |
+
if self.opts.gen.m.use_proj:
|
122 |
+
proj_dim = self.opts.gen.m.proj_dim
|
123 |
+
self.low_level_conv = Conv2dBlock(
|
124 |
+
self.input_dim[1],
|
125 |
+
proj_dim,
|
126 |
+
3,
|
127 |
+
padding=1,
|
128 |
+
activation="lrelu",
|
129 |
+
pad_type="reflect",
|
130 |
+
norm="spectral_batch",
|
131 |
+
)
|
132 |
+
self.high_level_conv = Conv2dBlock(
|
133 |
+
self.input_dim[0],
|
134 |
+
proj_dim,
|
135 |
+
3,
|
136 |
+
padding=1,
|
137 |
+
activation="lrelu",
|
138 |
+
pad_type="reflect",
|
139 |
+
norm="spectral_batch",
|
140 |
+
)
|
141 |
+
self.merge_feats_conv = Conv2dBlock(
|
142 |
+
proj_dim * 2,
|
143 |
+
self.z_nc,
|
144 |
+
3,
|
145 |
+
padding=1,
|
146 |
+
activation="lrelu",
|
147 |
+
pad_type="reflect",
|
148 |
+
norm="spectral_batch",
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
self.low_level_conv = Conv2dBlock(
|
152 |
+
self.input_dim[1],
|
153 |
+
self.input_dim[0],
|
154 |
+
3,
|
155 |
+
padding=1,
|
156 |
+
activation="lrelu",
|
157 |
+
pad_type="reflect",
|
158 |
+
norm="spectral_batch",
|
159 |
+
)
|
160 |
+
self.merge_feats_conv = Conv2dBlock(
|
161 |
+
self.input_dim[0] * 2,
|
162 |
+
self.z_nc,
|
163 |
+
3,
|
164 |
+
padding=1,
|
165 |
+
activation="lrelu",
|
166 |
+
pad_type="reflect",
|
167 |
+
norm="spectral_batch",
|
168 |
+
)
|
169 |
+
|
170 |
+
elif opts.gen.encoder.architecture == "deeplabv2":
|
171 |
+
self.input_dim = 2048
|
172 |
+
self.fc_conv = Conv2dBlock(
|
173 |
+
self.input_dim,
|
174 |
+
self.z_nc,
|
175 |
+
3,
|
176 |
+
padding=1,
|
177 |
+
activation="lrelu",
|
178 |
+
pad_type="reflect",
|
179 |
+
norm="spectral_batch",
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
raise ValueError("Unknown encoder type")
|
183 |
+
|
184 |
+
self.spade_blocks = []
|
185 |
+
|
186 |
+
for i in range(self.num_layers):
|
187 |
+
self.spade_blocks.append(
|
188 |
+
SPADEResnetBlock(
|
189 |
+
int(self.z_nc / (2 ** i)),
|
190 |
+
int(self.z_nc / (2 ** (i + 1))),
|
191 |
+
cond_nc,
|
192 |
+
spade_use_spectral_norm,
|
193 |
+
spade_param_free_norm,
|
194 |
+
spade_kernel_size,
|
195 |
+
spade_activation,
|
196 |
+
).cuda()
|
197 |
+
)
|
198 |
+
self.spade_blocks = nn.Sequential(*self.spade_blocks)
|
199 |
+
|
200 |
+
self.final_nc = int(self.z_nc / (2 ** self.num_layers))
|
201 |
+
self.mask_conv = Conv2dBlock(
|
202 |
+
self.final_nc,
|
203 |
+
1,
|
204 |
+
3,
|
205 |
+
padding=1,
|
206 |
+
activation="none",
|
207 |
+
pad_type="reflect",
|
208 |
+
norm="spectral",
|
209 |
+
)
|
210 |
+
self.upsample = InterpolateNearest2d(scale_factor=2)
|
211 |
+
|
212 |
+
def forward(self, z, cond, z_depth=None):
|
213 |
+
if isinstance(z, (list, tuple)):
|
214 |
+
z_h, z_l = z
|
215 |
+
if self.opts.gen.m.use_proj:
|
216 |
+
z_l = self.low_level_conv(z_l)
|
217 |
+
z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
|
218 |
+
z_h = self.high_level_conv(z_h)
|
219 |
+
else:
|
220 |
+
z_l = self.low_level_conv(z_l)
|
221 |
+
z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
|
222 |
+
z = torch.cat([z_h, z_l], axis=1)
|
223 |
+
y = self.merge_feats_conv(z)
|
224 |
+
else:
|
225 |
+
y = self.fc_conv(z)
|
226 |
+
|
227 |
+
for i in range(self.num_layers):
|
228 |
+
y = self.spade_blocks[i](y, cond)
|
229 |
+
y = self.upsample(y)
|
230 |
+
y = self.mask_conv(y)
|
231 |
+
return y
|
232 |
+
|
233 |
+
def __str__(self):
|
234 |
+
return "MaskerSpadeDecoder"
|
climategan/norms.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Normalization layers used in blocks
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class AdaptiveInstanceNorm2d(nn.Module):
|
9 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
10 |
+
super(AdaptiveInstanceNorm2d, self).__init__()
|
11 |
+
self.num_features = num_features
|
12 |
+
self.eps = eps
|
13 |
+
self.momentum = momentum
|
14 |
+
# weight and bias are dynamically assigned
|
15 |
+
self.weight = None
|
16 |
+
self.bias = None
|
17 |
+
# just dummy buffers, not used
|
18 |
+
self.register_buffer("running_mean", torch.zeros(num_features))
|
19 |
+
self.register_buffer("running_var", torch.ones(num_features))
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
assert (
|
23 |
+
self.weight is not None and self.bias is not None
|
24 |
+
), "Please assign weight and bias before calling AdaIN!"
|
25 |
+
b, c = x.size(0), x.size(1)
|
26 |
+
running_mean = self.running_mean.repeat(b)
|
27 |
+
running_var = self.running_var.repeat(b)
|
28 |
+
|
29 |
+
# Apply instance norm
|
30 |
+
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
|
31 |
+
|
32 |
+
out = F.batch_norm(
|
33 |
+
x_reshaped,
|
34 |
+
running_mean,
|
35 |
+
running_var,
|
36 |
+
self.weight,
|
37 |
+
self.bias,
|
38 |
+
True,
|
39 |
+
self.momentum,
|
40 |
+
self.eps,
|
41 |
+
)
|
42 |
+
|
43 |
+
return out.view(b, c, *x.size()[2:])
|
44 |
+
|
45 |
+
def __repr__(self):
|
46 |
+
return self.__class__.__name__ + "(" + str(self.num_features) + ")"
|
47 |
+
|
48 |
+
|
49 |
+
class LayerNorm(nn.Module):
|
50 |
+
def __init__(self, num_features, eps=1e-5, affine=True):
|
51 |
+
super(LayerNorm, self).__init__()
|
52 |
+
self.num_features = num_features
|
53 |
+
self.affine = affine
|
54 |
+
self.eps = eps
|
55 |
+
|
56 |
+
if self.affine:
|
57 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
|
58 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
62 |
+
# print(x.size())
|
63 |
+
if x.size(0) == 1:
|
64 |
+
# These two lines run much faster in pytorch 0.4
|
65 |
+
# than the two lines listed below.
|
66 |
+
mean = x.view(-1).mean().view(*shape)
|
67 |
+
std = x.view(-1).std().view(*shape)
|
68 |
+
else:
|
69 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
70 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
71 |
+
|
72 |
+
x = (x - mean) / (std + self.eps)
|
73 |
+
|
74 |
+
if self.affine:
|
75 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
76 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
def l2normalize(v, eps=1e-12):
|
81 |
+
return v / (v.norm() + eps)
|
82 |
+
|
83 |
+
|
84 |
+
class SpectralNorm(nn.Module):
|
85 |
+
"""
|
86 |
+
Based on the paper "Spectral Normalization for Generative Adversarial Networks"
|
87 |
+
by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
|
88 |
+
Pytorch implementation:
|
89 |
+
https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, module, name="weight", power_iterations=1):
|
93 |
+
super().__init__()
|
94 |
+
self.module = module
|
95 |
+
self.name = name
|
96 |
+
self.power_iterations = power_iterations
|
97 |
+
if not self._made_params():
|
98 |
+
self._make_params()
|
99 |
+
|
100 |
+
def _update_u_v(self):
|
101 |
+
u = getattr(self.module, self.name + "_u")
|
102 |
+
v = getattr(self.module, self.name + "_v")
|
103 |
+
w = getattr(self.module, self.name + "_bar")
|
104 |
+
|
105 |
+
height = w.data.shape[0]
|
106 |
+
for _ in range(self.power_iterations):
|
107 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
|
108 |
+
u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
|
109 |
+
|
110 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
111 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
112 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
113 |
+
|
114 |
+
def _made_params(self):
|
115 |
+
try:
|
116 |
+
u = getattr(self.module, self.name + "_u") # noqa: F841
|
117 |
+
v = getattr(self.module, self.name + "_v") # noqa: F841
|
118 |
+
w = getattr(self.module, self.name + "_bar") # noqa: F841
|
119 |
+
return True
|
120 |
+
except AttributeError:
|
121 |
+
return False
|
122 |
+
|
123 |
+
def _make_params(self):
|
124 |
+
w = getattr(self.module, self.name)
|
125 |
+
|
126 |
+
height = w.data.shape[0]
|
127 |
+
width = w.view(height, -1).data.shape[1]
|
128 |
+
|
129 |
+
u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
130 |
+
v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
131 |
+
u.data = l2normalize(u.data)
|
132 |
+
v.data = l2normalize(v.data)
|
133 |
+
w_bar = nn.Parameter(w.data)
|
134 |
+
|
135 |
+
del self.module._parameters[self.name]
|
136 |
+
|
137 |
+
self.module.register_parameter(self.name + "_u", u)
|
138 |
+
self.module.register_parameter(self.name + "_v", v)
|
139 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
140 |
+
|
141 |
+
def forward(self, *args):
|
142 |
+
self._update_u_v()
|
143 |
+
return self.module.forward(*args)
|
144 |
+
|
145 |
+
|
146 |
+
class SPADE(nn.Module):
|
147 |
+
def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc):
|
148 |
+
super().__init__()
|
149 |
+
|
150 |
+
if param_free_norm_type == "instance":
|
151 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
152 |
+
# elif param_free_norm_type == "syncbatch":
|
153 |
+
# self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
|
154 |
+
elif param_free_norm_type == "batch":
|
155 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
156 |
+
else:
|
157 |
+
raise ValueError(
|
158 |
+
"%s is not a recognized param-free norm type in SPADE"
|
159 |
+
% param_free_norm_type
|
160 |
+
)
|
161 |
+
|
162 |
+
# The dimension of the intermediate embedding space. Yes, hardcoded.
|
163 |
+
nhidden = 128
|
164 |
+
|
165 |
+
pw = kernel_size // 2
|
166 |
+
self.mlp_shared = nn.Sequential(
|
167 |
+
nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=pw), nn.ReLU()
|
168 |
+
)
|
169 |
+
self.mlp_gamma = nn.Conv2d(
|
170 |
+
nhidden, norm_nc, kernel_size=kernel_size, padding=pw
|
171 |
+
)
|
172 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw)
|
173 |
+
|
174 |
+
def forward(self, x, segmap):
|
175 |
+
# Part 1. generate parameter-free normalized activations
|
176 |
+
normalized = self.param_free_norm(x)
|
177 |
+
|
178 |
+
# Part 2. produce scaling and bias conditioned on semantic map
|
179 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
|
180 |
+
actv = self.mlp_shared(segmap)
|
181 |
+
gamma = self.mlp_gamma(actv)
|
182 |
+
beta = self.mlp_beta(actv)
|
183 |
+
# apply scale and bias
|
184 |
+
out = normalized * (1 + gamma) + beta
|
185 |
+
|
186 |
+
return out
|
climategan/optim.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Define ExtraAdam and schedulers
|
2 |
+
"""
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.optim import Adam, Optimizer, RMSprop, lr_scheduler
|
7 |
+
from torch_optimizer import NovoGrad, RAdam
|
8 |
+
|
9 |
+
|
10 |
+
def get_scheduler(optimizer, hyperparameters, iterations=-1):
|
11 |
+
"""Get an optimizer's learning rate scheduler based on opts
|
12 |
+
|
13 |
+
Args:
|
14 |
+
optimizer (torch.Optimizer): optimizer for which to schedule the learning rate
|
15 |
+
hyperparameters (addict.Dict): configuration options
|
16 |
+
iterations (int, optional): The index of last epoch. Defaults to -1.
|
17 |
+
When last_epoch=-1, sets initial lr as lr.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
[type]: [description]
|
21 |
+
"""
|
22 |
+
|
23 |
+
policy = hyperparameters.get("lr_policy")
|
24 |
+
lr_step_size = hyperparameters.get("lr_step_size")
|
25 |
+
lr_gamma = hyperparameters.get("lr_gamma")
|
26 |
+
milestones = hyperparameters.get("lr_milestones")
|
27 |
+
|
28 |
+
if policy is None or policy == "constant":
|
29 |
+
scheduler = None # constant scheduler
|
30 |
+
elif policy == "step":
|
31 |
+
scheduler = lr_scheduler.StepLR(
|
32 |
+
optimizer, step_size=lr_step_size, gamma=lr_gamma, last_epoch=iterations,
|
33 |
+
)
|
34 |
+
elif policy == "multi_step":
|
35 |
+
if isinstance(milestones, (list, tuple)):
|
36 |
+
milestones = milestones
|
37 |
+
elif isinstance(milestones, int):
|
38 |
+
assert "lr_step_size" in hyperparameters
|
39 |
+
if iterations == -1:
|
40 |
+
last_milestone = 1000
|
41 |
+
else:
|
42 |
+
last_milestone = iterations
|
43 |
+
milestones = list(range(milestones, last_milestone, lr_step_size))
|
44 |
+
scheduler = lr_scheduler.MultiStepLR(
|
45 |
+
optimizer, milestones=milestones, gamma=lr_gamma, last_epoch=iterations,
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
return NotImplementedError(
|
49 |
+
"learning rate policy [%s] is not implemented", hyperparameters["lr_policy"]
|
50 |
+
)
|
51 |
+
return scheduler
|
52 |
+
|
53 |
+
|
54 |
+
def get_optimizer(net, opt_conf, tasks=None, is_disc=False, iterations=-1):
|
55 |
+
"""Returns a tuple (optimizer, scheduler) according to opt_conf which
|
56 |
+
should come from the trainer's opts as: trainer.opts.<model>.opt
|
57 |
+
|
58 |
+
Args:
|
59 |
+
net (nn.Module): Network to update
|
60 |
+
opt_conf (addict.Dict): optimizer and scheduler options
|
61 |
+
tasks: list of tasks
|
62 |
+
iterations (int, optional): Last epoch number. Defaults to -1, meaning
|
63 |
+
start with base lr.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Tuple: (torch.Optimizer, torch._LRScheduler)
|
67 |
+
"""
|
68 |
+
opt = scheduler = None
|
69 |
+
lr_names = []
|
70 |
+
if tasks is None:
|
71 |
+
lr_default = opt_conf.lr
|
72 |
+
params = net.parameters()
|
73 |
+
lr_names.append("full")
|
74 |
+
elif isinstance(opt_conf.lr, float): # Use default for all tasks
|
75 |
+
lr_default = opt_conf.lr
|
76 |
+
params = net.parameters()
|
77 |
+
lr_names.append("full")
|
78 |
+
elif len(opt_conf.lr) == 1: # Use default for all tasks
|
79 |
+
lr_default = opt_conf.lr.default
|
80 |
+
params = net.parameters()
|
81 |
+
lr_names.append("full")
|
82 |
+
else:
|
83 |
+
lr_default = opt_conf.lr.default
|
84 |
+
params = list()
|
85 |
+
for task in tasks:
|
86 |
+
lr = opt_conf.lr.get(task, lr_default)
|
87 |
+
parameters = None
|
88 |
+
# Parameters for encoder
|
89 |
+
if not is_disc:
|
90 |
+
if task == "m":
|
91 |
+
parameters = net.encoder.parameters()
|
92 |
+
params.append({"params": parameters, "lr": lr})
|
93 |
+
lr_names.append("encoder")
|
94 |
+
# Parameters for decoders
|
95 |
+
if task == "p":
|
96 |
+
if hasattr(net, "painter"):
|
97 |
+
parameters = net.painter.parameters()
|
98 |
+
lr_names.append("painter")
|
99 |
+
else:
|
100 |
+
parameters = net.decoders[task].parameters()
|
101 |
+
lr_names.append(f"decoder_{task}")
|
102 |
+
else:
|
103 |
+
if task in net:
|
104 |
+
parameters = net[task].parameters()
|
105 |
+
lr_names.append(f"disc_{task}")
|
106 |
+
|
107 |
+
if parameters is not None:
|
108 |
+
params.append({"params": parameters, "lr": lr})
|
109 |
+
|
110 |
+
if opt_conf.optimizer.lower() == "extraadam":
|
111 |
+
opt = ExtraAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
|
112 |
+
elif opt_conf.optimizer.lower() == "novograd":
|
113 |
+
opt = NovoGrad(
|
114 |
+
params, lr=lr_default, betas=(opt_conf.beta1, 0)
|
115 |
+
) # default for beta2 is 0
|
116 |
+
elif opt_conf.optimizer.lower() == "radam":
|
117 |
+
opt = RAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
|
118 |
+
elif opt_conf.optimizer.lower() == "rmsprop":
|
119 |
+
opt = RMSprop(params, lr=lr_default)
|
120 |
+
else:
|
121 |
+
opt = Adam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
|
122 |
+
scheduler = get_scheduler(opt, opt_conf, iterations)
|
123 |
+
return opt, scheduler, lr_names
|
124 |
+
|
125 |
+
|
126 |
+
"""
|
127 |
+
Extragradient Optimizer
|
128 |
+
|
129 |
+
Mostly copied from the extragrad paper repo.
|
130 |
+
|
131 |
+
MIT License
|
132 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
133 |
+
written by Hugo Berard ([email protected]) while at Facebook.
|
134 |
+
"""
|
135 |
+
|
136 |
+
|
137 |
+
class Extragradient(Optimizer):
|
138 |
+
"""Base class for optimizers with extrapolation step.
|
139 |
+
Arguments:
|
140 |
+
params (iterable): an iterable of :class:`torch.Tensor` s or
|
141 |
+
:class:`dict` s. Specifies what Tensors should be optimized.
|
142 |
+
defaults: (dict): a dict containing default values of optimization
|
143 |
+
options (used when a parameter group doesn't specify them).
|
144 |
+
"""
|
145 |
+
|
146 |
+
def __init__(self, params, defaults):
|
147 |
+
super(Extragradient, self).__init__(params, defaults)
|
148 |
+
self.params_copy = []
|
149 |
+
|
150 |
+
def update(self, p, group):
|
151 |
+
raise NotImplementedError
|
152 |
+
|
153 |
+
def extrapolation(self):
|
154 |
+
"""Performs the extrapolation step and save a copy of the current
|
155 |
+
parameters for the update step.
|
156 |
+
"""
|
157 |
+
# Check if a copy of the parameters was already made.
|
158 |
+
is_empty = len(self.params_copy) == 0
|
159 |
+
for group in self.param_groups:
|
160 |
+
for p in group["params"]:
|
161 |
+
u = self.update(p, group)
|
162 |
+
if is_empty:
|
163 |
+
# Save the current parameters for the update step.
|
164 |
+
# Several extrapolation step can be made before each update but
|
165 |
+
# only the parametersbefore the first extrapolation step are saved.
|
166 |
+
self.params_copy.append(p.data.clone())
|
167 |
+
if u is None:
|
168 |
+
continue
|
169 |
+
# Update the current parameters
|
170 |
+
p.data.add_(u)
|
171 |
+
|
172 |
+
def step(self, closure=None):
|
173 |
+
"""Performs a single optimization step.
|
174 |
+
Arguments:
|
175 |
+
closure (callable, optional): A closure that reevaluates the model
|
176 |
+
and returns the loss.
|
177 |
+
"""
|
178 |
+
if len(self.params_copy) == 0:
|
179 |
+
raise RuntimeError("Need to call extrapolation before calling step.")
|
180 |
+
|
181 |
+
loss = None
|
182 |
+
if closure is not None:
|
183 |
+
loss = closure()
|
184 |
+
|
185 |
+
i = -1
|
186 |
+
for group in self.param_groups:
|
187 |
+
for p in group["params"]:
|
188 |
+
i += 1
|
189 |
+
u = self.update(p, group)
|
190 |
+
if u is None:
|
191 |
+
continue
|
192 |
+
# Update the parameters saved during the extrapolation step
|
193 |
+
p.data = self.params_copy[i].add_(u)
|
194 |
+
|
195 |
+
# Free the old parameters
|
196 |
+
self.params_copy = []
|
197 |
+
return loss
|
198 |
+
|
199 |
+
|
200 |
+
class ExtraAdam(Extragradient):
|
201 |
+
"""Implements the Adam algorithm with extrapolation step.
|
202 |
+
Arguments:
|
203 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
204 |
+
parameter groups
|
205 |
+
lr (float, optional): learning rate (default: 1e-3)
|
206 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
207 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
208 |
+
eps (float, optional): term added to the denominator to improve
|
209 |
+
numerical stability (default: 1e-8)
|
210 |
+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
211 |
+
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
212 |
+
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(
|
216 |
+
self,
|
217 |
+
params,
|
218 |
+
lr=1e-3,
|
219 |
+
betas=(0.9, 0.999),
|
220 |
+
eps=1e-8,
|
221 |
+
weight_decay=0,
|
222 |
+
amsgrad=False,
|
223 |
+
):
|
224 |
+
if not 0.0 <= lr:
|
225 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
226 |
+
if not 0.0 <= eps:
|
227 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
228 |
+
if not 0.0 <= betas[0] < 1.0:
|
229 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
230 |
+
if not 0.0 <= betas[1] < 1.0:
|
231 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
232 |
+
defaults = dict(
|
233 |
+
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
|
234 |
+
)
|
235 |
+
super(ExtraAdam, self).__init__(params, defaults)
|
236 |
+
|
237 |
+
def __setstate__(self, state):
|
238 |
+
super(ExtraAdam, self).__setstate__(state)
|
239 |
+
for group in self.param_groups:
|
240 |
+
group.setdefault("amsgrad", False)
|
241 |
+
|
242 |
+
def update(self, p, group):
|
243 |
+
if p.grad is None:
|
244 |
+
return None
|
245 |
+
grad = p.grad.data
|
246 |
+
if grad.is_sparse:
|
247 |
+
raise RuntimeError(
|
248 |
+
"Adam does not support sparse gradients,"
|
249 |
+
+ " please consider SparseAdam instead"
|
250 |
+
)
|
251 |
+
amsgrad = group["amsgrad"]
|
252 |
+
|
253 |
+
state = self.state[p]
|
254 |
+
|
255 |
+
# State initialization
|
256 |
+
if len(state) == 0:
|
257 |
+
state["step"] = 0
|
258 |
+
# Exponential moving average of gradient values
|
259 |
+
state["exp_avg"] = torch.zeros_like(p.data)
|
260 |
+
# Exponential moving average of squared gradient values
|
261 |
+
state["exp_avg_sq"] = torch.zeros_like(p.data)
|
262 |
+
if amsgrad:
|
263 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
264 |
+
state["max_exp_avg_sq"] = torch.zeros_like(p.data)
|
265 |
+
|
266 |
+
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
267 |
+
if amsgrad:
|
268 |
+
max_exp_avg_sq = state["max_exp_avg_sq"]
|
269 |
+
beta1, beta2 = group["betas"]
|
270 |
+
|
271 |
+
state["step"] += 1
|
272 |
+
|
273 |
+
if group["weight_decay"] != 0:
|
274 |
+
grad = grad.add(group["weight_decay"], p.data)
|
275 |
+
|
276 |
+
# Decay the first and second moment running average coefficient
|
277 |
+
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
278 |
+
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
279 |
+
if amsgrad:
|
280 |
+
# Maintains the maximum of all 2nd moment running avg. till now
|
281 |
+
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # type: ignore
|
282 |
+
# Use the max. for normalizing running avg. of gradient
|
283 |
+
denom = max_exp_avg_sq.sqrt().add_(group["eps"]) # type: ignore
|
284 |
+
else:
|
285 |
+
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
286 |
+
|
287 |
+
bias_correction1 = 1 - beta1 ** state["step"]
|
288 |
+
bias_correction2 = 1 - beta2 ** state["step"]
|
289 |
+
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
|
290 |
+
|
291 |
+
return -step_size * exp_avg / denom
|
climategan/painter.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
import climategan.strings as strings
|
6 |
+
from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
|
7 |
+
from climategan.norms import SpectralNorm
|
8 |
+
|
9 |
+
|
10 |
+
def create_painter(opts, no_init=False, verbose=0):
|
11 |
+
if verbose > 0:
|
12 |
+
print(" - Add PainterSpadeDecoder Painter")
|
13 |
+
return PainterSpadeDecoder(opts)
|
14 |
+
|
15 |
+
|
16 |
+
class PainterSpadeDecoder(nn.Module):
|
17 |
+
def __init__(self, opts):
|
18 |
+
"""Create a SPADE-based decoder, which forwards z and the conditioning
|
19 |
+
tensors seg (in the original paper, conditioning is on a semantic map only).
|
20 |
+
All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
|
21 |
+
the channel dimension, and an upsampling is applied after each. Therefore
|
22 |
+
2 upsamplings at this point. Then, for each remaining upsamplings
|
23 |
+
(w.r.t. spade_n_up), the SRB shrinks channels by 2. Before final conv to get 3
|
24 |
+
channels, the number of channels is therefore:
|
25 |
+
final_nc = channels(z) * 2 ** (spade_n_up - 2)
|
26 |
+
Args:
|
27 |
+
latent_dim (tuple): z's shape (only the number of channels matters)
|
28 |
+
cond_nc (int): conditioning tensor's expected number of channels
|
29 |
+
spade_n_up (int): Number of total upsamplings from z
|
30 |
+
spade_use_spectral_norm (bool): use spectral normalization?
|
31 |
+
spade_param_free_norm (str): norm to use before SPADE de-normalization
|
32 |
+
spade_kernel_size (int): SPADE conv layers' kernel size
|
33 |
+
Returns:
|
34 |
+
[type]: [description]
|
35 |
+
"""
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
latent_dim = opts.gen.p.latent_dim
|
39 |
+
cond_nc = 3
|
40 |
+
spade_n_up = opts.gen.p.spade_n_up
|
41 |
+
spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
|
42 |
+
spade_param_free_norm = opts.gen.p.spade_param_free_norm
|
43 |
+
spade_kernel_size = 3
|
44 |
+
|
45 |
+
self.z_nc = latent_dim
|
46 |
+
self.spade_n_up = spade_n_up
|
47 |
+
|
48 |
+
self.z_h = self.z_w = None
|
49 |
+
|
50 |
+
self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)
|
51 |
+
self.head_0 = SPADEResnetBlock(
|
52 |
+
self.z_nc,
|
53 |
+
self.z_nc,
|
54 |
+
cond_nc,
|
55 |
+
spade_use_spectral_norm,
|
56 |
+
spade_param_free_norm,
|
57 |
+
spade_kernel_size,
|
58 |
+
)
|
59 |
+
|
60 |
+
self.G_middle_0 = SPADEResnetBlock(
|
61 |
+
self.z_nc,
|
62 |
+
self.z_nc,
|
63 |
+
cond_nc,
|
64 |
+
spade_use_spectral_norm,
|
65 |
+
spade_param_free_norm,
|
66 |
+
spade_kernel_size,
|
67 |
+
)
|
68 |
+
self.G_middle_1 = SPADEResnetBlock(
|
69 |
+
self.z_nc,
|
70 |
+
self.z_nc,
|
71 |
+
cond_nc,
|
72 |
+
spade_use_spectral_norm,
|
73 |
+
spade_param_free_norm,
|
74 |
+
spade_kernel_size,
|
75 |
+
)
|
76 |
+
|
77 |
+
self.up_spades = nn.Sequential(
|
78 |
+
*[
|
79 |
+
SPADEResnetBlock(
|
80 |
+
self.z_nc // 2 ** i,
|
81 |
+
self.z_nc // 2 ** (i + 1),
|
82 |
+
cond_nc,
|
83 |
+
spade_use_spectral_norm,
|
84 |
+
spade_param_free_norm,
|
85 |
+
spade_kernel_size,
|
86 |
+
)
|
87 |
+
for i in range(spade_n_up - 2)
|
88 |
+
]
|
89 |
+
)
|
90 |
+
|
91 |
+
self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)
|
92 |
+
|
93 |
+
self.final_spade = SPADEResnetBlock(
|
94 |
+
self.final_nc,
|
95 |
+
self.final_nc,
|
96 |
+
cond_nc,
|
97 |
+
spade_use_spectral_norm,
|
98 |
+
spade_param_free_norm,
|
99 |
+
spade_kernel_size,
|
100 |
+
)
|
101 |
+
self.final_shortcut = None
|
102 |
+
if opts.gen.p.use_final_shortcut:
|
103 |
+
self.final_shortcut = nn.Sequential(
|
104 |
+
*[
|
105 |
+
SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
|
106 |
+
nn.BatchNorm2d(3),
|
107 |
+
nn.LeakyReLU(0.2, True),
|
108 |
+
]
|
109 |
+
)
|
110 |
+
|
111 |
+
self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)
|
112 |
+
|
113 |
+
self.upsample = InterpolateNearest2d(scale_factor=2)
|
114 |
+
|
115 |
+
def set_latent_shape(self, shape, is_input=True):
|
116 |
+
"""
|
117 |
+
Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
|
118 |
+
If is_input is True, then this is the actual input shape which should
|
119 |
+
be divided by 2 ** spade_n_up
|
120 |
+
Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]
|
121 |
+
|
122 |
+
Args:
|
123 |
+
shape (tuple): The shape to start sampling from.
|
124 |
+
is_input (bool, optional): Whether to divide shape by 2 ** spade_n_up
|
125 |
+
"""
|
126 |
+
if isinstance(shape, (list, tuple)):
|
127 |
+
self.z_h = shape[-2]
|
128 |
+
self.z_w = shape[-1]
|
129 |
+
elif isinstance(shape, int):
|
130 |
+
self.z_h = self.z_w = shape
|
131 |
+
else:
|
132 |
+
raise ValueError("Unknown shape type:", shape)
|
133 |
+
|
134 |
+
if is_input:
|
135 |
+
self.z_h = self.z_h // (2 ** self.spade_n_up)
|
136 |
+
self.z_w = self.z_w // (2 ** self.spade_n_up)
|
137 |
+
|
138 |
+
def _apply(self, fn):
|
139 |
+
# print("Applying SpadeDecoder", fn)
|
140 |
+
super()._apply(fn)
|
141 |
+
# self.head_0 = fn(self.head_0)
|
142 |
+
# self.G_middle_0 = fn(self.G_middle_0)
|
143 |
+
# self.G_middle_1 = fn(self.G_middle_1)
|
144 |
+
# for i, up in enumerate(self.up_spades):
|
145 |
+
# self.up_spades[i] = fn(up)
|
146 |
+
# self.conv_img = fn(self.conv_img)
|
147 |
+
return self
|
148 |
+
|
149 |
+
def forward(self, z, cond):
|
150 |
+
if z is None:
|
151 |
+
assert self.z_h is not None and self.z_w is not None
|
152 |
+
z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
|
153 |
+
y = self.head_0(z, cond)
|
154 |
+
y = self.upsample(y)
|
155 |
+
y = self.G_middle_0(y, cond)
|
156 |
+
y = self.upsample(y)
|
157 |
+
y = self.G_middle_1(y, cond)
|
158 |
+
|
159 |
+
for i, up in enumerate(self.up_spades):
|
160 |
+
y = self.upsample(y)
|
161 |
+
y = up(y, cond)
|
162 |
+
|
163 |
+
if self.final_shortcut is not None:
|
164 |
+
cond = self.final_shortcut(y)
|
165 |
+
y = self.final_spade(y, cond)
|
166 |
+
y = self.conv_img(F.leaky_relu(y, 2e-1))
|
167 |
+
y = torch.tanh(y)
|
168 |
+
return y
|
169 |
+
|
170 |
+
def __str__(self):
|
171 |
+
return strings.spadedecoder(self)
|
climategan/strings.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""custom __str__ methods for ClimateGAN's classes
|
2 |
+
"""
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
def title(name, color="\033[94m"):
|
8 |
+
name = "==== " + name + " ===="
|
9 |
+
s = "=" * len(name)
|
10 |
+
s = f"{s}\n{name}\n{s}"
|
11 |
+
return f"\033[1m{color}{s}\033[0m"
|
12 |
+
|
13 |
+
|
14 |
+
def generator(G):
|
15 |
+
s = title("OmniGenerator", "\033[95m") + "\n"
|
16 |
+
|
17 |
+
s += str(G.encoder) + "\n\n"
|
18 |
+
for d in G.decoders:
|
19 |
+
if d not in {"a", "t"}:
|
20 |
+
s += str(G.decoders[d]) + "\n\n"
|
21 |
+
elif d == "a":
|
22 |
+
s += "[r & s]\n" + str(G.decoders["a"]["r"]) + "\n\n"
|
23 |
+
else:
|
24 |
+
if G.opts.gen.t.use_bit_conditioning:
|
25 |
+
s += "[bit]\n" + str(G.decoders["t"]) + "\n\n"
|
26 |
+
else:
|
27 |
+
s += "[f & n]\n" + str(G.decoders["t"]["f"]) + "\n\n"
|
28 |
+
return s.strip()
|
29 |
+
|
30 |
+
|
31 |
+
def encoder(E):
|
32 |
+
s = title("Encoder") + "\n"
|
33 |
+
for b in E.model:
|
34 |
+
s += str(b) + "\n"
|
35 |
+
return s.strip()
|
36 |
+
|
37 |
+
|
38 |
+
def get_conv_weight(conv):
|
39 |
+
weight = torch.Tensor(
|
40 |
+
conv.out_channels, conv.in_channels // conv.groups, *conv.kernel_size
|
41 |
+
)
|
42 |
+
return weight.shape
|
43 |
+
|
44 |
+
|
45 |
+
def conv2dblock(obj):
|
46 |
+
name = "{:20}".format("Conv2dBlock")
|
47 |
+
s = ""
|
48 |
+
if "SpectralNorm" in obj.conv.__class__.__name__:
|
49 |
+
s = "SpectralNorm => "
|
50 |
+
w = str(tuple(get_conv_weight(obj.conv.module)))
|
51 |
+
else:
|
52 |
+
w = str(tuple(get_conv_weight(obj.conv)))
|
53 |
+
return f"{name}{s}{w}".strip()
|
54 |
+
|
55 |
+
|
56 |
+
def resblocks(rb):
|
57 |
+
s = "{}\n".format(f"ResBlocks({len(rb.model)})")
|
58 |
+
for i, r in enumerate(rb.model):
|
59 |
+
s += f" - ({i}) {str(r)}\n"
|
60 |
+
return s.strip()
|
61 |
+
|
62 |
+
|
63 |
+
def resblock(rb):
|
64 |
+
s = "{:12}".format("Resblock")
|
65 |
+
return f"{s}{rb.dim} channels, {rb.norm} norm + {rb.activation}"
|
66 |
+
|
67 |
+
|
68 |
+
def basedecoder(bd):
|
69 |
+
s = title(bd.__class__.__name__) + "\n"
|
70 |
+
for b in bd.model:
|
71 |
+
if isinstance(b, nn.Upsample) or "InterpolateNearest2d" in b.__class__.__name__:
|
72 |
+
s += "{:20}".format("Upsample") + "x2\n"
|
73 |
+
else:
|
74 |
+
s += str(b) + "\n"
|
75 |
+
return s.strip()
|
76 |
+
|
77 |
+
|
78 |
+
def spaderesblock(srb):
|
79 |
+
name = "{:20}".format("SPADEResnetBlock") + f"k {srb.kernel_size}, "
|
80 |
+
s = f"{name}{srb.fin} > {srb.fout}, "
|
81 |
+
s += f"param_free_norm: {srb.param_free_norm}, "
|
82 |
+
s += f"spectral_norm: {srb.use_spectral_norm}"
|
83 |
+
return s.strip()
|
84 |
+
|
85 |
+
|
86 |
+
def spadedecoder(sd):
|
87 |
+
s = title(sd.__class__.__name__) + "\n"
|
88 |
+
up = "{:20}x2\n".format("Upsample")
|
89 |
+
s += up
|
90 |
+
s += str(sd.head_0) + "\n"
|
91 |
+
s += up
|
92 |
+
s += str(sd.G_middle_0) + "\n"
|
93 |
+
s += up
|
94 |
+
s += str(sd.G_middle_1) + "\n"
|
95 |
+
for i, u in enumerate(sd.up_spades):
|
96 |
+
s += up
|
97 |
+
s += str(u) + "\n"
|
98 |
+
s += "{:20}".format("Conv2d") + str(tuple(get_conv_weight(sd.conv_img))) + " tanh"
|
99 |
+
return s
|
climategan/trainer.py
ADDED
@@ -0,0 +1,1939 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main component: the trainer handles everything:
|
3 |
+
* initializations
|
4 |
+
* training
|
5 |
+
* saving
|
6 |
+
"""
|
7 |
+
import inspect
|
8 |
+
import warnings
|
9 |
+
from copy import deepcopy
|
10 |
+
from pathlib import Path
|
11 |
+
from time import time
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from comet_ml import ExistingExperiment, Experiment
|
15 |
+
|
16 |
+
warnings.simplefilter("ignore", UserWarning)
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
from addict import Dict
|
21 |
+
from torch import autograd, sigmoid, softmax
|
22 |
+
from torch.cuda.amp import GradScaler, autocast
|
23 |
+
from tqdm import tqdm
|
24 |
+
|
25 |
+
from climategan.data import get_all_loaders
|
26 |
+
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
27 |
+
from climategan.eval_metrics import accuracy, mIOU
|
28 |
+
from climategan.fid import compute_val_fid
|
29 |
+
from climategan.fire import add_fire
|
30 |
+
from climategan.generator import OmniGenerator, create_generator
|
31 |
+
from climategan.logger import Logger
|
32 |
+
from climategan.losses import get_losses
|
33 |
+
from climategan.optim import get_optimizer
|
34 |
+
from climategan.transforms import DiffTransforms
|
35 |
+
from climategan.tutils import (
|
36 |
+
divide_pred,
|
37 |
+
get_num_params,
|
38 |
+
get_WGAN_gradient,
|
39 |
+
lrgb2srgb,
|
40 |
+
normalize,
|
41 |
+
print_num_parameters,
|
42 |
+
shuffle_batch_tuple,
|
43 |
+
srgb2lrgb,
|
44 |
+
vgg_preprocess,
|
45 |
+
zero_grad,
|
46 |
+
)
|
47 |
+
from climategan.utils import (
|
48 |
+
comet_kwargs,
|
49 |
+
div_dict,
|
50 |
+
find_target_size,
|
51 |
+
flatten_opts,
|
52 |
+
get_display_indices,
|
53 |
+
get_existing_comet_id,
|
54 |
+
get_latest_opts,
|
55 |
+
merge,
|
56 |
+
resolve,
|
57 |
+
sum_dict,
|
58 |
+
Timer,
|
59 |
+
)
|
60 |
+
|
61 |
+
try:
|
62 |
+
import torch_xla.core.xla_model as xm # type: ignore
|
63 |
+
except ImportError:
|
64 |
+
pass
|
65 |
+
|
66 |
+
|
67 |
+
class Trainer:
|
68 |
+
"""Main trainer class"""
|
69 |
+
|
70 |
+
def __init__(self, opts, comet_exp=None, verbose=0, device=None):
|
71 |
+
"""Trainer class to gather various model training procedures
|
72 |
+
such as training evaluating saving and logging
|
73 |
+
|
74 |
+
init:
|
75 |
+
* creates an addict.Dict logger
|
76 |
+
* creates logger.exp as a comet_exp experiment if `comet` arg is True
|
77 |
+
* sets the device (1 GPU or CPU)
|
78 |
+
|
79 |
+
Args:
|
80 |
+
opts (addict.Dict): options to configure the trainer, the data, the models
|
81 |
+
comet (bool, optional): whether to log the trainer with comet.ml.
|
82 |
+
Defaults to False.
|
83 |
+
verbose (int, optional): printing level to debug. Defaults to 0.
|
84 |
+
"""
|
85 |
+
super().__init__()
|
86 |
+
|
87 |
+
self.opts = opts
|
88 |
+
self.verbose = verbose
|
89 |
+
self.logger = Logger(self)
|
90 |
+
|
91 |
+
self.losses = None
|
92 |
+
self.G = self.D = None
|
93 |
+
self.real_val_fid_stats = None
|
94 |
+
self.use_pl4m = False
|
95 |
+
self.is_setup = False
|
96 |
+
self.loaders = self.all_loaders = None
|
97 |
+
self.exp = None
|
98 |
+
|
99 |
+
self.current_mode = "train"
|
100 |
+
self.diff_transforms = None
|
101 |
+
self.kitti_pretrain = self.opts.train.kitti.pretrain
|
102 |
+
self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)
|
103 |
+
|
104 |
+
self.lr_names = {}
|
105 |
+
self.base_display_images = {}
|
106 |
+
self.kitty_display_images = {}
|
107 |
+
self.domain_labels = {"s": 0, "r": 1}
|
108 |
+
|
109 |
+
self.device = device or torch.device(
|
110 |
+
"cuda:0" if torch.cuda.is_available() else "cpu"
|
111 |
+
)
|
112 |
+
|
113 |
+
if isinstance(comet_exp, Experiment):
|
114 |
+
self.exp = comet_exp
|
115 |
+
|
116 |
+
if self.opts.train.amp:
|
117 |
+
optimizers = [
|
118 |
+
self.opts.gen.opt.optimizer.lower(),
|
119 |
+
self.opts.dis.opt.optimizer.lower(),
|
120 |
+
]
|
121 |
+
if "extraadam" in optimizers:
|
122 |
+
raise ValueError(
|
123 |
+
"AMP does not work with ExtraAdam ({})".format(optimizers)
|
124 |
+
)
|
125 |
+
self.grad_scaler_d = GradScaler()
|
126 |
+
self.grad_scaler_g = GradScaler()
|
127 |
+
|
128 |
+
# -------------------------------
|
129 |
+
# ----- Legacy Overwrites -----
|
130 |
+
# -------------------------------
|
131 |
+
if (
|
132 |
+
self.opts.gen.s.depth_feat_fusion is True
|
133 |
+
or self.opts.gen.s.depth_dada_fusion is True
|
134 |
+
):
|
135 |
+
self.opts.gen.s.use_dada = True
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
|
139 |
+
"""
|
140 |
+
Paints a batch of images (or a single image with a batch dim of 1). If
|
141 |
+
masks are not provided, they are inferred from the masker.
|
142 |
+
Resolution can either be the train-time resolution or the closest
|
143 |
+
multiple of 2 ** spade_n_up
|
144 |
+
|
145 |
+
Operations performed without gradient
|
146 |
+
|
147 |
+
If resolution == "approx" then the output image has the shape:
|
148 |
+
(dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
|
149 |
+
eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
|
150 |
+
If resolution == "exact" then the output image has the same shape:
|
151 |
+
we first process in "approx" mode then upsample bilinear
|
152 |
+
If resolution == "basic" image output shape is the train-time's
|
153 |
+
(typically 640x640)
|
154 |
+
If resolution == "upsample" image is inferred as "basic" and
|
155 |
+
then upsampled to original size
|
156 |
+
|
157 |
+
Args:
|
158 |
+
image_batch (torch.Tensor): 4D batch of images to flood
|
159 |
+
mask_batch (torch.Tensor, optional): Masks for the images.
|
160 |
+
Defaults to None (infer with Masker).
|
161 |
+
resolution (str, optional): "approx", "exact" or False
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
torch.Tensor: N x C x H x W where H and W depend on `resolution`
|
165 |
+
"""
|
166 |
+
assert resolution in {"approx", "exact", "basic", "upsample"}
|
167 |
+
previous_mode = self.current_mode
|
168 |
+
if previous_mode == "train":
|
169 |
+
self.eval_mode()
|
170 |
+
|
171 |
+
if mask_batch is None:
|
172 |
+
mask_batch = self.G.mask(x=image_batch)
|
173 |
+
else:
|
174 |
+
assert len(image_batch) == len(mask_batch)
|
175 |
+
assert image_batch.shape[-2:] == mask_batch.shape[-2:]
|
176 |
+
|
177 |
+
if resolution not in {"approx", "exact"}:
|
178 |
+
painted = self.G.paint(mask_batch, image_batch)
|
179 |
+
|
180 |
+
if resolution == "upsample":
|
181 |
+
painted = nn.functional.interpolate(
|
182 |
+
painted, size=image_batch.shape[-2:], mode="bilinear"
|
183 |
+
)
|
184 |
+
else:
|
185 |
+
# save latent shape
|
186 |
+
zh = self.G.painter.z_h
|
187 |
+
zw = self.G.painter.z_w
|
188 |
+
# adapt latent shape to approximately keep the resolution
|
189 |
+
self.G.painter.z_h = (
|
190 |
+
image_batch.shape[-2] // 2**self.opts.gen.p.spade_n_up
|
191 |
+
)
|
192 |
+
self.G.painter.z_w = (
|
193 |
+
image_batch.shape[-1] // 2**self.opts.gen.p.spade_n_up
|
194 |
+
)
|
195 |
+
|
196 |
+
painted = self.G.paint(mask_batch, image_batch)
|
197 |
+
|
198 |
+
self.G.painter.z_h = zh
|
199 |
+
self.G.painter.z_w = zw
|
200 |
+
if resolution == "exact":
|
201 |
+
painted = nn.functional.interpolate(
|
202 |
+
painted, size=image_batch.shape[-2:], mode="bilinear"
|
203 |
+
)
|
204 |
+
|
205 |
+
if previous_mode == "train":
|
206 |
+
self.train_mode()
|
207 |
+
|
208 |
+
return painted
|
209 |
+
|
210 |
+
def _p(self, *args, **kwargs):
|
211 |
+
"""
|
212 |
+
verbose-dependant print util
|
213 |
+
"""
|
214 |
+
if self.verbose > 0:
|
215 |
+
print(*args, **kwargs)
|
216 |
+
|
217 |
+
@torch.no_grad()
|
218 |
+
def infer_all(
|
219 |
+
self,
|
220 |
+
x,
|
221 |
+
numpy=True,
|
222 |
+
stores={},
|
223 |
+
bin_value=-1,
|
224 |
+
half=False,
|
225 |
+
xla=False,
|
226 |
+
cloudy=False,
|
227 |
+
auto_resize_640=False,
|
228 |
+
ignore_event=set(),
|
229 |
+
return_masks=False,
|
230 |
+
):
|
231 |
+
"""
|
232 |
+
Create a dictionnary of events from a numpy or tensor,
|
233 |
+
single or batch image data.
|
234 |
+
|
235 |
+
stores is a dictionnary of times for the Timer class.
|
236 |
+
|
237 |
+
bin_value is used to binarize (or not) flood masks
|
238 |
+
"""
|
239 |
+
assert self.is_setup
|
240 |
+
assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
|
241 |
+
|
242 |
+
# convert numpy to tensor
|
243 |
+
if not isinstance(x, torch.Tensor):
|
244 |
+
x = torch.tensor(x, device=self.device)
|
245 |
+
|
246 |
+
# add batch dimension
|
247 |
+
if len(x.shape) == 3:
|
248 |
+
x.unsqueeze_(0)
|
249 |
+
|
250 |
+
# permute channels as second dimension
|
251 |
+
if x.shape[1] != 3:
|
252 |
+
assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
|
253 |
+
x = x.permute(0, 3, 1, 2)
|
254 |
+
|
255 |
+
# send to device
|
256 |
+
if x.device != self.device:
|
257 |
+
x = x.to(self.device)
|
258 |
+
|
259 |
+
# interpolate to standard input size
|
260 |
+
if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
|
261 |
+
x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")
|
262 |
+
|
263 |
+
if half:
|
264 |
+
x = x.half()
|
265 |
+
|
266 |
+
# adjust painter's latent vector
|
267 |
+
self.G.painter.set_latent_shape(x.shape, True)
|
268 |
+
|
269 |
+
with Timer(store=stores.get("all events", [])):
|
270 |
+
# encode
|
271 |
+
with Timer(store=stores.get("encode", [])):
|
272 |
+
z = self.G.encode(x)
|
273 |
+
if xla:
|
274 |
+
xm.mark_step()
|
275 |
+
|
276 |
+
# predict from masker
|
277 |
+
with Timer(store=stores.get("depth", [])):
|
278 |
+
depth, z_depth = self.G.decoders["d"](z)
|
279 |
+
if xla:
|
280 |
+
xm.mark_step()
|
281 |
+
with Timer(store=stores.get("segmentation", [])):
|
282 |
+
segmentation = self.G.decoders["s"](z, z_depth)
|
283 |
+
if xla:
|
284 |
+
xm.mark_step()
|
285 |
+
with Timer(store=stores.get("mask", [])):
|
286 |
+
cond = self.G.make_m_cond(depth, segmentation, x)
|
287 |
+
mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
|
288 |
+
if xla:
|
289 |
+
xm.mark_step()
|
290 |
+
|
291 |
+
# apply events
|
292 |
+
if "wildfire" not in ignore_event:
|
293 |
+
with Timer(store=stores.get("wildfire", [])):
|
294 |
+
wildfire = self.compute_fire(x, seg_preds=segmentation)
|
295 |
+
if "smog" not in ignore_event:
|
296 |
+
with Timer(store=stores.get("smog", [])):
|
297 |
+
smog = self.compute_smog(x, d=depth, s=segmentation)
|
298 |
+
if "flood" not in ignore_event:
|
299 |
+
with Timer(store=stores.get("flood", [])):
|
300 |
+
flood = self.compute_flood(
|
301 |
+
x,
|
302 |
+
m=mask,
|
303 |
+
s=segmentation,
|
304 |
+
cloudy=cloudy,
|
305 |
+
bin_value=bin_value,
|
306 |
+
)
|
307 |
+
|
308 |
+
if xla:
|
309 |
+
xm.mark_step()
|
310 |
+
|
311 |
+
if numpy:
|
312 |
+
with Timer(store=stores.get("numpy", [])):
|
313 |
+
# normalize to 0-1
|
314 |
+
flood = normalize(flood).cpu()
|
315 |
+
smog = normalize(smog).cpu()
|
316 |
+
wildfire = normalize(wildfire).cpu()
|
317 |
+
|
318 |
+
# convert to numpy
|
319 |
+
flood = flood.permute(0, 2, 3, 1).numpy()
|
320 |
+
smog = smog.permute(0, 2, 3, 1).numpy()
|
321 |
+
wildfire = wildfire.permute(0, 2, 3, 1).numpy()
|
322 |
+
|
323 |
+
# convert to 0-255 uint8
|
324 |
+
flood = (flood * 255).astype(np.uint8)
|
325 |
+
smog = (smog * 255).astype(np.uint8)
|
326 |
+
wildfire = (wildfire * 255).astype(np.uint8)
|
327 |
+
|
328 |
+
output_data = {"flood": flood, "wildfire": wildfire, "smog": smog}
|
329 |
+
if return_masks:
|
330 |
+
output_data["mask"] = (
|
331 |
+
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
332 |
+
)
|
333 |
+
|
334 |
+
return output_data
|
335 |
+
|
336 |
+
@classmethod
|
337 |
+
def resume_from_path(
|
338 |
+
cls,
|
339 |
+
path,
|
340 |
+
overrides={},
|
341 |
+
setup=True,
|
342 |
+
inference=False,
|
343 |
+
new_exp=False,
|
344 |
+
device=None,
|
345 |
+
verbose=1,
|
346 |
+
):
|
347 |
+
"""
|
348 |
+
Resume and optionally setup a trainer from a specific path,
|
349 |
+
using the latest opts and checkpoint. Requires path to contain opts.yaml
|
350 |
+
(or increased), url.txt (or increased) and checkpoints/
|
351 |
+
|
352 |
+
Args:
|
353 |
+
path (str | pathlib.Path): Trainer to resume
|
354 |
+
overrides (dict, optional): Override loaded opts with those. Defaults to {}.
|
355 |
+
setup (bool, optional): Wether or not to setup the trainer before
|
356 |
+
returning it. Defaults to True.
|
357 |
+
inference (bool, optional): Setup should be done in inference mode or not.
|
358 |
+
Defaults to False.
|
359 |
+
new_exp (bool, optional): Re-use existing comet exp in path or create
|
360 |
+
a new one? Defaults to False.
|
361 |
+
device (torch.device, optional): Device to use
|
362 |
+
|
363 |
+
Returns:
|
364 |
+
climategan.Trainer: Loaded and resumed trainer
|
365 |
+
"""
|
366 |
+
p = resolve(path)
|
367 |
+
assert p.exists()
|
368 |
+
|
369 |
+
c = p / "checkpoints"
|
370 |
+
assert c.exists() and c.is_dir()
|
371 |
+
|
372 |
+
opts = get_latest_opts(p)
|
373 |
+
opts = Dict(merge(overrides, opts))
|
374 |
+
opts.train.resume = True
|
375 |
+
|
376 |
+
if new_exp is None:
|
377 |
+
exp = None
|
378 |
+
elif new_exp is True:
|
379 |
+
exp = Experiment(project_name="climategan", **comet_kwargs)
|
380 |
+
exp.log_asset_folder(
|
381 |
+
str(resolve(Path(__file__)).parent),
|
382 |
+
recursive=True,
|
383 |
+
log_file_name=True,
|
384 |
+
)
|
385 |
+
exp.log_parameters(flatten_opts(opts))
|
386 |
+
else:
|
387 |
+
comet_id = get_existing_comet_id(p)
|
388 |
+
exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)
|
389 |
+
|
390 |
+
trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)
|
391 |
+
|
392 |
+
if setup:
|
393 |
+
trainer.setup(inference=inference)
|
394 |
+
return trainer
|
395 |
+
|
396 |
+
def save(self):
|
397 |
+
save_dir = Path(self.opts.output_path) / Path("checkpoints")
|
398 |
+
save_dir.mkdir(exist_ok=True)
|
399 |
+
save_path = save_dir / "latest_ckpt.pth"
|
400 |
+
|
401 |
+
# Construct relevant state dicts / optims:
|
402 |
+
# Save at least G
|
403 |
+
save_dict = {
|
404 |
+
"epoch": self.logger.epoch,
|
405 |
+
"G": self.G.state_dict(),
|
406 |
+
"g_opt": self.g_opt.state_dict(),
|
407 |
+
"step": self.logger.global_step,
|
408 |
+
}
|
409 |
+
|
410 |
+
if self.D is not None and get_num_params(self.D) > 0:
|
411 |
+
save_dict["D"] = self.D.state_dict()
|
412 |
+
save_dict["d_opt"] = self.d_opt.state_dict()
|
413 |
+
|
414 |
+
if (
|
415 |
+
self.logger.epoch >= self.opts.train.min_save_epoch
|
416 |
+
and self.logger.epoch % self.opts.train.save_n_epochs == 0
|
417 |
+
):
|
418 |
+
torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")
|
419 |
+
|
420 |
+
torch.save(save_dict, save_path)
|
421 |
+
|
422 |
+
def resume(self, inference=False):
|
423 |
+
tpu = "xla" in str(self.device)
|
424 |
+
if tpu:
|
425 |
+
print("Resuming on TPU:", self.device)
|
426 |
+
|
427 |
+
m_path = Path(self.opts.load_paths.m)
|
428 |
+
p_path = Path(self.opts.load_paths.p)
|
429 |
+
pm_path = Path(self.opts.load_paths.pm)
|
430 |
+
output_path = Path(self.opts.output_path)
|
431 |
+
|
432 |
+
map_loc = self.device if not tpu else "cpu"
|
433 |
+
|
434 |
+
if "m" in self.opts.tasks and "p" in self.opts.tasks:
|
435 |
+
# ----------------------------------------
|
436 |
+
# ----- Masker and Painter Loading -----
|
437 |
+
# ----------------------------------------
|
438 |
+
|
439 |
+
# want to resume a pm model but no path was provided:
|
440 |
+
# resume a single pm model from output_path
|
441 |
+
if all([str(p) == "none" for p in [m_path, p_path, pm_path]]):
|
442 |
+
checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
|
443 |
+
print("Resuming P+M model from", str(checkpoint_path))
|
444 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_loc)
|
445 |
+
|
446 |
+
# want to resume a pm model with a pm_path provided:
|
447 |
+
# resume a single pm model from load_paths.pm
|
448 |
+
# depending on whether a dir or a file is specified
|
449 |
+
elif str(pm_path) != "none":
|
450 |
+
assert pm_path.exists()
|
451 |
+
|
452 |
+
if pm_path.is_dir():
|
453 |
+
checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
|
454 |
+
else:
|
455 |
+
assert pm_path.suffix == ".pth"
|
456 |
+
checkpoint_path = pm_path
|
457 |
+
|
458 |
+
print("Resuming P+M model from", str(checkpoint_path))
|
459 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_loc)
|
460 |
+
|
461 |
+
# want to resume a pm model, pm_path not provided:
|
462 |
+
# m_path and p_path must be provided as dirs or pth files
|
463 |
+
elif m_path != p_path:
|
464 |
+
assert m_path.exists()
|
465 |
+
assert p_path.exists()
|
466 |
+
|
467 |
+
if m_path.is_dir():
|
468 |
+
m_path = m_path / "checkpoints/latest_ckpt.pth"
|
469 |
+
|
470 |
+
if p_path.is_dir():
|
471 |
+
p_path = p_path / "checkpoints/latest_ckpt.pth"
|
472 |
+
|
473 |
+
assert m_path.suffix == ".pth"
|
474 |
+
assert p_path.suffix == ".pth"
|
475 |
+
|
476 |
+
print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}")
|
477 |
+
m_checkpoint = torch.load(m_path, map_location=map_loc)
|
478 |
+
p_checkpoint = torch.load(p_path, map_location=map_loc)
|
479 |
+
checkpoint = merge(m_checkpoint, p_checkpoint)
|
480 |
+
|
481 |
+
else:
|
482 |
+
raise ValueError(
|
483 |
+
"Cannot resume a P+M model with provided load_paths:\n{}".format(
|
484 |
+
self.opts.load_paths
|
485 |
+
)
|
486 |
+
)
|
487 |
+
|
488 |
+
else:
|
489 |
+
# ----------------------------------
|
490 |
+
# ----- Single Model Loading -----
|
491 |
+
# ----------------------------------
|
492 |
+
|
493 |
+
# cannot specify both paths
|
494 |
+
if str(m_path) != "none" and str(p_path) != "none":
|
495 |
+
raise ValueError(
|
496 |
+
"Opts tasks are {} but received 2 values for the load_paths".format(
|
497 |
+
self.opts.tasks
|
498 |
+
)
|
499 |
+
)
|
500 |
+
|
501 |
+
# specified m
|
502 |
+
elif str(m_path) != "none":
|
503 |
+
assert m_path.exists()
|
504 |
+
assert "m" in self.opts.tasks
|
505 |
+
model = "M"
|
506 |
+
if m_path.is_dir():
|
507 |
+
m_path = m_path / "checkpoints/latest_ckpt.pth"
|
508 |
+
checkpoint_path = m_path
|
509 |
+
|
510 |
+
# specified m
|
511 |
+
elif str(p_path) != "none":
|
512 |
+
assert p_path.exists()
|
513 |
+
assert "p" in self.opts.tasks
|
514 |
+
model = "P"
|
515 |
+
if p_path.is_dir():
|
516 |
+
p_path = p_path / "checkpoints/latest_ckpt.pth"
|
517 |
+
checkpoint_path = p_path
|
518 |
+
|
519 |
+
# specified neither p nor m: resume from output_path
|
520 |
+
else:
|
521 |
+
model = "P" if "p" in self.opts.tasks else "M"
|
522 |
+
checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
|
523 |
+
|
524 |
+
print(f"Resuming {model} model from {checkpoint_path}")
|
525 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_loc)
|
526 |
+
|
527 |
+
# On TPUs must send the data to the xla device as it cannot be mapped
|
528 |
+
# there directly from torch.load
|
529 |
+
if tpu:
|
530 |
+
checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)
|
531 |
+
|
532 |
+
# -----------------------
|
533 |
+
# ----- Restore G -----
|
534 |
+
# -----------------------
|
535 |
+
if inference:
|
536 |
+
incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
|
537 |
+
if incompatible_keys.missing_keys:
|
538 |
+
print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
|
539 |
+
print(incompatible_keys.missing_keys)
|
540 |
+
if incompatible_keys.unexpected_keys:
|
541 |
+
print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
|
542 |
+
print(incompatible_keys.unexpected_keys)
|
543 |
+
else:
|
544 |
+
self.G.load_state_dict(checkpoint["G"])
|
545 |
+
|
546 |
+
if inference:
|
547 |
+
# only G is needed to infer
|
548 |
+
print("Done loading checkpoints.")
|
549 |
+
return
|
550 |
+
|
551 |
+
self.g_opt.load_state_dict(checkpoint["g_opt"])
|
552 |
+
|
553 |
+
# ------------------------------
|
554 |
+
# ----- Resume scheduler -----
|
555 |
+
# ------------------------------
|
556 |
+
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
|
557 |
+
for _ in range(self.logger.epoch + 1):
|
558 |
+
self.update_learning_rates()
|
559 |
+
|
560 |
+
# -----------------------
|
561 |
+
# ----- Restore D -----
|
562 |
+
# -----------------------
|
563 |
+
if self.D is not None and get_num_params(self.D) > 0:
|
564 |
+
self.D.load_state_dict(checkpoint["D"])
|
565 |
+
self.d_opt.load_state_dict(checkpoint["d_opt"])
|
566 |
+
|
567 |
+
# ---------------------------
|
568 |
+
# ----- Resore logger -----
|
569 |
+
# ---------------------------
|
570 |
+
self.logger.epoch = checkpoint["epoch"]
|
571 |
+
self.logger.global_step = checkpoint["step"]
|
572 |
+
self.exp.log_text(
|
573 |
+
"Resuming from epoch {} & step {}".format(
|
574 |
+
checkpoint["epoch"], checkpoint["step"]
|
575 |
+
)
|
576 |
+
)
|
577 |
+
# Round step to even number for extraGradient
|
578 |
+
if self.logger.global_step % 2 != 0:
|
579 |
+
self.logger.global_step += 1
|
580 |
+
|
581 |
+
def eval_mode(self):
|
582 |
+
"""
|
583 |
+
Set trainer's models in eval mode
|
584 |
+
"""
|
585 |
+
if self.G is not None:
|
586 |
+
self.G.eval()
|
587 |
+
if self.D is not None:
|
588 |
+
self.D.eval()
|
589 |
+
self.current_mode = "eval"
|
590 |
+
|
591 |
+
def train_mode(self):
|
592 |
+
"""
|
593 |
+
Set trainer's models in train mode
|
594 |
+
"""
|
595 |
+
if self.G is not None:
|
596 |
+
self.G.train()
|
597 |
+
if self.D is not None:
|
598 |
+
self.D.train()
|
599 |
+
|
600 |
+
self.current_mode = "train"
|
601 |
+
|
602 |
+
def assert_z_matches_x(self, x, z):
|
603 |
+
assert x.shape[0] == (
|
604 |
+
z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0]
|
605 |
+
), "x-> {}, z->{}".format(
|
606 |
+
x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape
|
607 |
+
)
|
608 |
+
|
609 |
+
def batch_to_device(self, b):
|
610 |
+
"""sends the data in b to self.device
|
611 |
+
|
612 |
+
Args:
|
613 |
+
b (dict): the batch dictionnay
|
614 |
+
|
615 |
+
Returns:
|
616 |
+
dict: the batch dictionnary with its "data" field sent to self.device
|
617 |
+
"""
|
618 |
+
for task, tensor in b["data"].items():
|
619 |
+
b["data"][task] = tensor.to(self.device)
|
620 |
+
return b
|
621 |
+
|
622 |
+
def sample_painter_z(self, batch_size):
|
623 |
+
return self.G.sample_painter_z(batch_size, self.device)
|
624 |
+
|
625 |
+
@property
|
626 |
+
def train_loaders(self):
|
627 |
+
"""Get a zip of all training loaders
|
628 |
+
|
629 |
+
Returns:
|
630 |
+
generator: zip generator yielding tuples:
|
631 |
+
(batch_rf, batch_rn, batch_sf, batch_sn)
|
632 |
+
"""
|
633 |
+
return zip(*list(self.loaders["train"].values()))
|
634 |
+
|
635 |
+
@property
|
636 |
+
def val_loaders(self):
|
637 |
+
"""Get a zip of all validation loaders
|
638 |
+
|
639 |
+
Returns:
|
640 |
+
generator: zip generator yielding tuples:
|
641 |
+
(batch_rf, batch_rn, batch_sf, batch_sn)
|
642 |
+
"""
|
643 |
+
return zip(*list(self.loaders["val"].values()))
|
644 |
+
|
645 |
+
def compute_latent_shape(self):
|
646 |
+
"""Compute the latent shape, i.e. the Encoder's output shape,
|
647 |
+
from a batch.
|
648 |
+
|
649 |
+
Raises:
|
650 |
+
ValueError: If no loader, the latent_shape cannot be inferred
|
651 |
+
|
652 |
+
Returns:
|
653 |
+
tuple: (c, h, w)
|
654 |
+
"""
|
655 |
+
x = None
|
656 |
+
for mode in self.all_loaders:
|
657 |
+
for domain in self.all_loaders.loaders[mode]:
|
658 |
+
x = (
|
659 |
+
self.all_loaders[mode][domain]
|
660 |
+
.dataset[0]["data"]["x"]
|
661 |
+
.to(self.device)
|
662 |
+
)
|
663 |
+
break
|
664 |
+
if x is not None:
|
665 |
+
break
|
666 |
+
|
667 |
+
if x is None:
|
668 |
+
raise ValueError("No batch found to compute_latent_shape")
|
669 |
+
|
670 |
+
x = x.unsqueeze(0)
|
671 |
+
z = self.G.encode(x)
|
672 |
+
return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]
|
673 |
+
|
674 |
+
def g_opt_step(self):
|
675 |
+
"""Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
|
676 |
+
step every other step
|
677 |
+
"""
|
678 |
+
if "extra" in self.opts.gen.opt.optimizer.lower() and (
|
679 |
+
self.logger.global_step % 2 == 0
|
680 |
+
):
|
681 |
+
self.g_opt.extrapolation()
|
682 |
+
else:
|
683 |
+
self.g_opt.step()
|
684 |
+
|
685 |
+
def d_opt_step(self):
|
686 |
+
"""Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
|
687 |
+
step every other step
|
688 |
+
"""
|
689 |
+
if "extra" in self.opts.dis.opt.optimizer.lower() and (
|
690 |
+
self.logger.global_step % 2 == 0
|
691 |
+
):
|
692 |
+
self.d_opt.extrapolation()
|
693 |
+
else:
|
694 |
+
self.d_opt.step()
|
695 |
+
|
696 |
+
def update_learning_rates(self):
|
697 |
+
if self.g_scheduler is not None:
|
698 |
+
self.g_scheduler.step()
|
699 |
+
if self.d_scheduler is not None:
|
700 |
+
self.d_scheduler.step()
|
701 |
+
|
702 |
+
def setup(self, inference=False):
|
703 |
+
"""Prepare the trainer before it can be used to train the models:
|
704 |
+
* initialize G and D
|
705 |
+
* creates 2 optimizers
|
706 |
+
"""
|
707 |
+
self.logger.global_step = 0
|
708 |
+
start_time = time()
|
709 |
+
self.logger.time.start_time = start_time
|
710 |
+
verbose = self.verbose
|
711 |
+
|
712 |
+
if not inference:
|
713 |
+
self.all_loaders = get_all_loaders(self.opts)
|
714 |
+
|
715 |
+
# -----------------------
|
716 |
+
# ----- Generator -----
|
717 |
+
# -----------------------
|
718 |
+
__t = time()
|
719 |
+
print("Creating generator...")
|
720 |
+
|
721 |
+
self.G: OmniGenerator = create_generator(
|
722 |
+
self.opts, device=self.device, no_init=inference, verbose=verbose
|
723 |
+
)
|
724 |
+
|
725 |
+
self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()
|
726 |
+
|
727 |
+
if self.has_painter:
|
728 |
+
self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)
|
729 |
+
|
730 |
+
print(f"Generator OK in {time() - __t:.1f}s.")
|
731 |
+
|
732 |
+
if inference: # Inference mode: no more than a Generator needed
|
733 |
+
print("Inference mode: no Discriminator, no optimizers")
|
734 |
+
print_num_parameters(self)
|
735 |
+
self.switch_data(to="base")
|
736 |
+
if self.opts.train.resume:
|
737 |
+
self.resume(True)
|
738 |
+
self.eval_mode()
|
739 |
+
print("Trainer is in evaluation mode.")
|
740 |
+
print("Setup done.")
|
741 |
+
self.is_setup = True
|
742 |
+
return
|
743 |
+
|
744 |
+
# ---------------------------
|
745 |
+
# ----- Discriminator -----
|
746 |
+
# ---------------------------
|
747 |
+
|
748 |
+
self.D: OmniDiscriminator = create_discriminator(
|
749 |
+
self.opts, self.device, verbose=verbose
|
750 |
+
)
|
751 |
+
print("Discriminator OK.")
|
752 |
+
|
753 |
+
print_num_parameters(self)
|
754 |
+
|
755 |
+
# --------------------------
|
756 |
+
# ----- Optimization -----
|
757 |
+
# --------------------------
|
758 |
+
# Get different optimizers for each task (different learning rates)
|
759 |
+
self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
|
760 |
+
self.G, self.opts.gen.opt, self.opts.tasks
|
761 |
+
)
|
762 |
+
|
763 |
+
if get_num_params(self.D) > 0:
|
764 |
+
self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
|
765 |
+
self.D, self.opts.dis.opt, self.opts.tasks, True
|
766 |
+
)
|
767 |
+
else:
|
768 |
+
self.d_opt, self.d_scheduler = None, None
|
769 |
+
|
770 |
+
self.losses = get_losses(self.opts, verbose, device=self.device)
|
771 |
+
|
772 |
+
if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
|
773 |
+
self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)
|
774 |
+
|
775 |
+
if verbose > 0:
|
776 |
+
for mode, mode_dict in self.all_loaders.items():
|
777 |
+
for domain, domain_loader in mode_dict.items():
|
778 |
+
print(
|
779 |
+
"Loader {} {} : {}".format(
|
780 |
+
mode, domain, len(domain_loader.dataset)
|
781 |
+
)
|
782 |
+
)
|
783 |
+
|
784 |
+
# ----------------------------
|
785 |
+
# ----- Display images -----
|
786 |
+
# ----------------------------
|
787 |
+
self.set_display_images()
|
788 |
+
|
789 |
+
# -------------------------------
|
790 |
+
# ----- Log Architectures -----
|
791 |
+
# -------------------------------
|
792 |
+
self.logger.log_architecture()
|
793 |
+
|
794 |
+
# -----------------------------
|
795 |
+
# ----- Set data source -----
|
796 |
+
# -----------------------------
|
797 |
+
if self.kitti_pretrain:
|
798 |
+
self.switch_data(to="kitti")
|
799 |
+
else:
|
800 |
+
self.switch_data(to="base")
|
801 |
+
|
802 |
+
# -------------------------
|
803 |
+
# ----- Setup Done. -----
|
804 |
+
# -------------------------
|
805 |
+
print(" " * 50, end="\r")
|
806 |
+
print("Done creating display images")
|
807 |
+
|
808 |
+
if self.opts.train.resume:
|
809 |
+
print("Resuming Model (inference: False)")
|
810 |
+
self.resume(False)
|
811 |
+
else:
|
812 |
+
print("Not resuming: starting a new model")
|
813 |
+
|
814 |
+
print("Setup done.")
|
815 |
+
self.is_setup = True
|
816 |
+
|
817 |
+
def switch_data(self, to="kitti"):
|
818 |
+
caller = inspect.stack()[1].function
|
819 |
+
print(f"[{caller}] Switching data source to", to)
|
820 |
+
self.data_source = to
|
821 |
+
if to == "kitti":
|
822 |
+
self.display_images = self.kitty_display_images
|
823 |
+
if self.all_loaders is not None:
|
824 |
+
self.loaders = {
|
825 |
+
mode: {"s": self.all_loaders[mode]["kitti"]}
|
826 |
+
for mode in self.all_loaders
|
827 |
+
}
|
828 |
+
else:
|
829 |
+
self.display_images = self.base_display_images
|
830 |
+
if self.all_loaders is not None:
|
831 |
+
self.loaders = {
|
832 |
+
mode: {
|
833 |
+
domain: self.all_loaders[mode][domain]
|
834 |
+
for domain in self.all_loaders[mode]
|
835 |
+
if domain != "kitti"
|
836 |
+
}
|
837 |
+
for mode in self.all_loaders
|
838 |
+
}
|
839 |
+
if (
|
840 |
+
self.logger.global_step % 2 != 0
|
841 |
+
and "extra" in self.opts.dis.opt.optimizer.lower()
|
842 |
+
):
|
843 |
+
print(
|
844 |
+
"Warning: artificially bumping step to run an extrapolation step first."
|
845 |
+
)
|
846 |
+
self.logger.global_step += 1
|
847 |
+
|
848 |
+
def set_display_images(self, use_all=False):
|
849 |
+
for mode, mode_dict in self.all_loaders.items():
|
850 |
+
|
851 |
+
if self.kitti_pretrain:
|
852 |
+
self.kitty_display_images[mode] = {}
|
853 |
+
self.base_display_images[mode] = {}
|
854 |
+
|
855 |
+
for domain in mode_dict:
|
856 |
+
|
857 |
+
if self.kitti_pretrain and domain == "kitti":
|
858 |
+
target_dict = self.kitty_display_images
|
859 |
+
else:
|
860 |
+
if domain == "kitti":
|
861 |
+
continue
|
862 |
+
target_dict = self.base_display_images
|
863 |
+
|
864 |
+
dataset = self.all_loaders[mode][domain].dataset
|
865 |
+
display_indices = (
|
866 |
+
get_display_indices(self.opts, domain, len(dataset))
|
867 |
+
if not use_all
|
868 |
+
else list(range(len(dataset)))
|
869 |
+
)
|
870 |
+
ldis = len(display_indices)
|
871 |
+
print(
|
872 |
+
f" Creating {ldis} {mode} {domain} display images...",
|
873 |
+
end="\r",
|
874 |
+
flush=True,
|
875 |
+
)
|
876 |
+
target_dict[mode][domain] = [
|
877 |
+
Dict(dataset[i])
|
878 |
+
for i in display_indices
|
879 |
+
if (print(f"({i})", end="\r") is None and i < len(dataset))
|
880 |
+
]
|
881 |
+
if self.exp is not None:
|
882 |
+
for im_id, d in enumerate(target_dict[mode][domain]):
|
883 |
+
self.exp.log_parameter(
|
884 |
+
"display_image_{}_{}_{}".format(mode, domain, im_id),
|
885 |
+
d["paths"],
|
886 |
+
)
|
887 |
+
|
888 |
+
def train(self):
|
889 |
+
"""For each epoch:
|
890 |
+
* train
|
891 |
+
* eval
|
892 |
+
* save
|
893 |
+
"""
|
894 |
+
assert self.is_setup
|
895 |
+
|
896 |
+
for self.logger.epoch in range(
|
897 |
+
self.logger.epoch, self.logger.epoch + self.opts.train.epochs
|
898 |
+
):
|
899 |
+
# backprop painter's disc loss to masker
|
900 |
+
if (
|
901 |
+
self.logger.epoch == self.opts.gen.p.pl4m_epoch
|
902 |
+
and get_num_params(self.G.painter) > 0
|
903 |
+
and "p" in self.opts.tasks
|
904 |
+
and self.opts.gen.m.use_pl4m
|
905 |
+
):
|
906 |
+
print(
|
907 |
+
"\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
|
908 |
+
)
|
909 |
+
self.use_pl4m = True
|
910 |
+
|
911 |
+
self.run_epoch()
|
912 |
+
self.run_evaluation(verbose=1)
|
913 |
+
self.save()
|
914 |
+
|
915 |
+
# end vkitti2 pre-training
|
916 |
+
if self.logger.epoch == self.opts.train.kitti.epochs - 1:
|
917 |
+
self.switch_data(to="base")
|
918 |
+
self.kitti_pretrain = False
|
919 |
+
|
920 |
+
# end pseudo training
|
921 |
+
if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
|
922 |
+
self.pseudo_training_tasks = set()
|
923 |
+
|
924 |
+
def run_epoch(self):
|
925 |
+
"""Runs an epoch:
|
926 |
+
* checks trainer is setup
|
927 |
+
* gets a tuple of batches per domain
|
928 |
+
* sends batches to device
|
929 |
+
* updates sequentially G, D
|
930 |
+
"""
|
931 |
+
assert self.is_setup
|
932 |
+
self.train_mode()
|
933 |
+
if self.exp is not None:
|
934 |
+
self.exp.log_parameter("epoch", self.logger.epoch)
|
935 |
+
epoch_len = min(len(loader) for loader in self.loaders["train"].values())
|
936 |
+
epoch_desc = "Epoch {}".format(self.logger.epoch)
|
937 |
+
self.logger.time.epoch_start = time()
|
938 |
+
|
939 |
+
for multi_batch_tuple in tqdm(
|
940 |
+
self.train_loaders,
|
941 |
+
desc=epoch_desc,
|
942 |
+
total=epoch_len,
|
943 |
+
mininterval=0.5,
|
944 |
+
unit="batch",
|
945 |
+
):
|
946 |
+
|
947 |
+
self.logger.time.step_start = time()
|
948 |
+
multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple)
|
949 |
+
|
950 |
+
# The `[0]` is because the domain is contained in a list
|
951 |
+
multi_domain_batch = {
|
952 |
+
batch["domain"][0]: self.batch_to_device(batch)
|
953 |
+
for batch in multi_batch_tuple
|
954 |
+
}
|
955 |
+
# ------------------------------
|
956 |
+
# ----- Update Generator -----
|
957 |
+
# ------------------------------
|
958 |
+
|
959 |
+
# freeze params of the discriminator
|
960 |
+
if self.d_opt is not None:
|
961 |
+
for param in self.D.parameters():
|
962 |
+
param.requires_grad = False
|
963 |
+
|
964 |
+
self.update_G(multi_domain_batch)
|
965 |
+
|
966 |
+
# ----------------------------------
|
967 |
+
# ----- Update Discriminator -----
|
968 |
+
# ----------------------------------
|
969 |
+
|
970 |
+
# unfreeze params of the discriminator
|
971 |
+
if self.d_opt is not None and not self.kitti_pretrain:
|
972 |
+
for param in self.D.parameters():
|
973 |
+
param.requires_grad = True
|
974 |
+
|
975 |
+
self.update_D(multi_domain_batch)
|
976 |
+
|
977 |
+
# -------------------------
|
978 |
+
# ----- Log Metrics -----
|
979 |
+
# -------------------------
|
980 |
+
self.logger.global_step += 1
|
981 |
+
self.logger.log_step_time(time())
|
982 |
+
|
983 |
+
if not self.kitti_pretrain:
|
984 |
+
self.update_learning_rates()
|
985 |
+
|
986 |
+
self.logger.log_learning_rates()
|
987 |
+
self.logger.log_epoch_time(time())
|
988 |
+
|
989 |
+
def update_G(self, multi_domain_batch, verbose=0):
|
990 |
+
"""Perform an update on g from multi_domain_batch which is a dictionary
|
991 |
+
domain => batch
|
992 |
+
|
993 |
+
* automatic mixed precision according to self.opts.train.amp
|
994 |
+
* compute loss for each task
|
995 |
+
* loss.backward()
|
996 |
+
* g_opt_step()
|
997 |
+
* g_opt.step() or .extrapolation() depending on self.logger.global_step
|
998 |
+
* logs losses on comet.ml with self.logger.log_losses(model_to_update="G")
|
999 |
+
|
1000 |
+
Args:
|
1001 |
+
multi_domain_batch (dict): dictionnary of domain batches
|
1002 |
+
"""
|
1003 |
+
zero_grad(self.G)
|
1004 |
+
if self.opts.train.amp:
|
1005 |
+
with autocast():
|
1006 |
+
g_loss = self.get_G_loss(multi_domain_batch, verbose)
|
1007 |
+
self.grad_scaler_g.scale(g_loss).backward()
|
1008 |
+
self.grad_scaler_g.step(self.g_opt)
|
1009 |
+
self.grad_scaler_g.update()
|
1010 |
+
else:
|
1011 |
+
g_loss = self.get_G_loss(multi_domain_batch, verbose)
|
1012 |
+
g_loss.backward()
|
1013 |
+
self.g_opt_step()
|
1014 |
+
|
1015 |
+
self.logger.log_losses(model_to_update="G", mode="train")
|
1016 |
+
|
1017 |
+
def update_D(self, multi_domain_batch, verbose=0):
|
1018 |
+
zero_grad(self.D)
|
1019 |
+
|
1020 |
+
if self.opts.train.amp:
|
1021 |
+
with autocast():
|
1022 |
+
d_loss = self.get_D_loss(multi_domain_batch, verbose)
|
1023 |
+
self.grad_scaler_d.scale(d_loss).backward()
|
1024 |
+
self.grad_scaler_d.step(self.d_opt)
|
1025 |
+
self.grad_scaler_d.update()
|
1026 |
+
else:
|
1027 |
+
d_loss = self.get_D_loss(multi_domain_batch, verbose)
|
1028 |
+
d_loss.backward()
|
1029 |
+
self.d_opt_step()
|
1030 |
+
|
1031 |
+
self.logger.losses.disc.total_loss = d_loss.item()
|
1032 |
+
self.logger.log_losses(model_to_update="D", mode="train")
|
1033 |
+
|
1034 |
+
def get_D_loss(self, multi_domain_batch, verbose=0):
|
1035 |
+
"""Compute the discriminators' losses:
|
1036 |
+
|
1037 |
+
* for each domain-specific batch:
|
1038 |
+
* encode the image
|
1039 |
+
* get the conditioning tensor if using spade
|
1040 |
+
* source domain is the data's domain, sequentially r|s then f|n
|
1041 |
+
* get the target domain accordingly
|
1042 |
+
* compute the translated image from the data
|
1043 |
+
* compute the source domain discriminator's loss on the data
|
1044 |
+
* compute the target domain discriminator's loss on the translated image
|
1045 |
+
|
1046 |
+
# ? In this setting, each D[decoder][domain] is updated twice towards
|
1047 |
+
# real or fake data
|
1048 |
+
|
1049 |
+
See readme's update d section for details
|
1050 |
+
|
1051 |
+
Args:
|
1052 |
+
multi_domain_batch ([type]): [description]
|
1053 |
+
|
1054 |
+
Returns:
|
1055 |
+
[type]: [description]
|
1056 |
+
"""
|
1057 |
+
|
1058 |
+
disc_loss = {
|
1059 |
+
"m": {"Advent": 0},
|
1060 |
+
"s": {"Advent": 0},
|
1061 |
+
}
|
1062 |
+
if self.opts.dis.p.use_local_discriminator:
|
1063 |
+
disc_loss["p"] = {"global": 0, "local": 0}
|
1064 |
+
else:
|
1065 |
+
disc_loss["p"] = {"gan": 0}
|
1066 |
+
|
1067 |
+
for domain, batch in multi_domain_batch.items():
|
1068 |
+
x = batch["data"]["x"]
|
1069 |
+
|
1070 |
+
# ---------------------
|
1071 |
+
# ----- Painter -----
|
1072 |
+
# ---------------------
|
1073 |
+
if domain == "rf" and self.has_painter:
|
1074 |
+
m = batch["data"]["m"]
|
1075 |
+
# sample vector
|
1076 |
+
with torch.no_grad():
|
1077 |
+
# see spade compute_discriminator_loss
|
1078 |
+
fake = self.G.paint(m, x)
|
1079 |
+
if self.opts.gen.p.diff_aug.use:
|
1080 |
+
fake = self.diff_transforms(fake)
|
1081 |
+
x = self.diff_transforms(x)
|
1082 |
+
fake = fake.detach()
|
1083 |
+
fake.requires_grad_()
|
1084 |
+
|
1085 |
+
if self.opts.dis.p.use_local_discriminator:
|
1086 |
+
fake_d_global = self.D["p"]["global"](fake)
|
1087 |
+
real_d_global = self.D["p"]["global"](x)
|
1088 |
+
|
1089 |
+
fake_d_local = self.D["p"]["local"](fake * m)
|
1090 |
+
real_d_local = self.D["p"]["local"](x * m)
|
1091 |
+
|
1092 |
+
global_loss = self.losses["D"]["p"](fake_d_global, False, True)
|
1093 |
+
global_loss += self.losses["D"]["p"](real_d_global, True, True)
|
1094 |
+
|
1095 |
+
local_loss = self.losses["D"]["p"](fake_d_local, False, True)
|
1096 |
+
local_loss += self.losses["D"]["p"](real_d_local, True, True)
|
1097 |
+
|
1098 |
+
disc_loss["p"]["global"] += global_loss
|
1099 |
+
disc_loss["p"]["local"] += local_loss
|
1100 |
+
else:
|
1101 |
+
real_cat = torch.cat([m, x], axis=1)
|
1102 |
+
fake_cat = torch.cat([m, fake], axis=1)
|
1103 |
+
real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
|
1104 |
+
real_fake_d = self.D["p"](real_fake_cat)
|
1105 |
+
real_d, fake_d = divide_pred(real_fake_d)
|
1106 |
+
disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
|
1107 |
+
disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)
|
1108 |
+
|
1109 |
+
# --------------------
|
1110 |
+
# ----- Masker -----
|
1111 |
+
# --------------------
|
1112 |
+
else:
|
1113 |
+
z = self.G.encode(x)
|
1114 |
+
s_pred = d_pred = cond = z_depth = None
|
1115 |
+
|
1116 |
+
if "s" in batch["data"]:
|
1117 |
+
if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
|
1118 |
+
d_pred, z_depth = self.G.decoders["d"](z)
|
1119 |
+
|
1120 |
+
step_loss, s_pred = self.masker_s_loss(
|
1121 |
+
x, z, d_pred, z_depth, None, domain, for_="D"
|
1122 |
+
)
|
1123 |
+
step_loss *= self.opts.train.lambdas.advent.adv_main
|
1124 |
+
disc_loss["s"]["Advent"] += step_loss
|
1125 |
+
|
1126 |
+
if "m" in batch["data"]:
|
1127 |
+
if "d" in self.opts.tasks:
|
1128 |
+
if self.opts.gen.m.use_spade:
|
1129 |
+
if d_pred is None:
|
1130 |
+
d_pred, z_depth = self.G.decoders["d"](z)
|
1131 |
+
cond = self.G.make_m_cond(d_pred, s_pred, x)
|
1132 |
+
elif self.opts.gen.m.use_dada:
|
1133 |
+
if d_pred is None:
|
1134 |
+
d_pred, z_depth = self.G.decoders["d"](z)
|
1135 |
+
|
1136 |
+
step_loss, _ = self.masker_m_loss(
|
1137 |
+
x,
|
1138 |
+
z,
|
1139 |
+
None,
|
1140 |
+
domain,
|
1141 |
+
for_="D",
|
1142 |
+
cond=cond,
|
1143 |
+
z_depth=z_depth,
|
1144 |
+
depth_preds=d_pred,
|
1145 |
+
)
|
1146 |
+
step_loss *= self.opts.train.lambdas.advent.adv_main
|
1147 |
+
disc_loss["m"]["Advent"] += step_loss
|
1148 |
+
|
1149 |
+
self.logger.losses.disc.update(
|
1150 |
+
{
|
1151 |
+
dom: {
|
1152 |
+
k: v.item() if isinstance(v, torch.Tensor) else v
|
1153 |
+
for k, v in d.items()
|
1154 |
+
}
|
1155 |
+
for dom, d in disc_loss.items()
|
1156 |
+
}
|
1157 |
+
)
|
1158 |
+
|
1159 |
+
loss = sum(v for d in disc_loss.values() for k, v in d.items())
|
1160 |
+
return loss
|
1161 |
+
|
1162 |
+
def get_G_loss(self, multi_domain_batch, verbose=0):
|
1163 |
+
m_loss = p_loss = None
|
1164 |
+
|
1165 |
+
# For now, always compute "representation loss"
|
1166 |
+
g_loss = 0
|
1167 |
+
|
1168 |
+
if any(t in self.opts.tasks for t in "msd"):
|
1169 |
+
m_loss = self.get_masker_loss(multi_domain_batch)
|
1170 |
+
self.logger.losses.gen.masker = m_loss.item()
|
1171 |
+
g_loss += m_loss
|
1172 |
+
|
1173 |
+
if "p" in self.opts.tasks and not self.kitti_pretrain:
|
1174 |
+
p_loss = self.get_painter_loss(multi_domain_batch)
|
1175 |
+
self.logger.losses.gen.painter = p_loss.item()
|
1176 |
+
g_loss += p_loss
|
1177 |
+
|
1178 |
+
assert g_loss != 0 and not isinstance(g_loss, int), "No update in get_G_loss!"
|
1179 |
+
|
1180 |
+
self.logger.losses.gen.total_loss = g_loss.item()
|
1181 |
+
|
1182 |
+
return g_loss
|
1183 |
+
|
1184 |
+
def get_masker_loss(self, multi_domain_batch): # TODO update docstrings
|
1185 |
+
"""Only update the representation part of the model, meaning everything
|
1186 |
+
but the translation part
|
1187 |
+
|
1188 |
+
* for each batch in available domains:
|
1189 |
+
* compute task-specific losses
|
1190 |
+
* compute the adaptation and translation decoders' auto-encoding losses
|
1191 |
+
* compute the adaptation decoder's translation losses (GAN and Cycle)
|
1192 |
+
|
1193 |
+
Args:
|
1194 |
+
multi_domain_batch (dict): dictionnary mapping domain names to batches from
|
1195 |
+
the trainer's loaders
|
1196 |
+
|
1197 |
+
Returns:
|
1198 |
+
torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
|
1199 |
+
"""
|
1200 |
+
m_loss = 0
|
1201 |
+
for domain, batch in multi_domain_batch.items():
|
1202 |
+
# We don't care about the flooded domain here
|
1203 |
+
if domain == "rf":
|
1204 |
+
continue
|
1205 |
+
|
1206 |
+
x = batch["data"]["x"]
|
1207 |
+
z = self.G.encode(x)
|
1208 |
+
|
1209 |
+
# --------------------------------------
|
1210 |
+
# ----- task-specific losses (2) -----
|
1211 |
+
# --------------------------------------
|
1212 |
+
d_pred = s_pred = z_depth = None
|
1213 |
+
for task in ["d", "s", "m"]:
|
1214 |
+
if task not in batch["data"]:
|
1215 |
+
continue
|
1216 |
+
|
1217 |
+
target = batch["data"][task]
|
1218 |
+
|
1219 |
+
if task == "d":
|
1220 |
+
loss, d_pred, z_depth = self.masker_d_loss(
|
1221 |
+
x, z, target, domain, "G"
|
1222 |
+
)
|
1223 |
+
m_loss += loss
|
1224 |
+
self.logger.losses.gen.task["d"][domain] = loss.item()
|
1225 |
+
|
1226 |
+
elif task == "s":
|
1227 |
+
loss, s_pred = self.masker_s_loss(
|
1228 |
+
x, z, d_pred, z_depth, target, domain, "G"
|
1229 |
+
)
|
1230 |
+
m_loss += loss
|
1231 |
+
self.logger.losses.gen.task["s"][domain] = loss.item()
|
1232 |
+
|
1233 |
+
elif task == "m":
|
1234 |
+
cond = None
|
1235 |
+
if self.opts.gen.m.use_spade:
|
1236 |
+
if not self.opts.gen.m.detach:
|
1237 |
+
d_pred = d_pred.clone()
|
1238 |
+
s_pred = s_pred.clone()
|
1239 |
+
cond = self.G.make_m_cond(d_pred, s_pred, x)
|
1240 |
+
|
1241 |
+
loss, _ = self.masker_m_loss(
|
1242 |
+
x,
|
1243 |
+
z,
|
1244 |
+
target,
|
1245 |
+
domain,
|
1246 |
+
"G",
|
1247 |
+
cond=cond,
|
1248 |
+
z_depth=z_depth,
|
1249 |
+
depth_preds=d_pred,
|
1250 |
+
)
|
1251 |
+
m_loss += loss
|
1252 |
+
self.logger.losses.gen.task["m"][domain] = loss.item()
|
1253 |
+
|
1254 |
+
return m_loss
|
1255 |
+
|
1256 |
+
def get_painter_loss(self, multi_domain_batch):
|
1257 |
+
"""Computes the translation loss when flooding/deflooding images
|
1258 |
+
|
1259 |
+
Args:
|
1260 |
+
multi_domain_batch (dict): dictionnary mapping domain names to batches from
|
1261 |
+
the trainer's loaders
|
1262 |
+
|
1263 |
+
Returns:
|
1264 |
+
torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
|
1265 |
+
"""
|
1266 |
+
step_loss = 0
|
1267 |
+
# self.g_opt.zero_grad()
|
1268 |
+
lambdas = self.opts.train.lambdas
|
1269 |
+
batch_domain = "rf"
|
1270 |
+
batch = multi_domain_batch[batch_domain]
|
1271 |
+
|
1272 |
+
x = batch["data"]["x"]
|
1273 |
+
# ! different mask: hides water to be reconstructed
|
1274 |
+
# ! 1 for water, 0 otherwise
|
1275 |
+
m = batch["data"]["m"]
|
1276 |
+
fake_flooded = self.G.paint(m, x)
|
1277 |
+
|
1278 |
+
# ----------------------
|
1279 |
+
# ----- VGG Loss -----
|
1280 |
+
# ----------------------
|
1281 |
+
if lambdas.G.p.vgg != 0:
|
1282 |
+
loss = self.losses["G"]["p"]["vgg"](
|
1283 |
+
vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
|
1284 |
+
)
|
1285 |
+
loss *= lambdas.G.p.vgg
|
1286 |
+
self.logger.losses.gen.p.vgg = loss.item()
|
1287 |
+
step_loss += loss
|
1288 |
+
|
1289 |
+
# ---------------------
|
1290 |
+
# ----- TV Loss -----
|
1291 |
+
# ---------------------
|
1292 |
+
if lambdas.G.p.tv != 0:
|
1293 |
+
loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
|
1294 |
+
loss *= lambdas.G.p.tv
|
1295 |
+
self.logger.losses.gen.p.tv = loss.item()
|
1296 |
+
step_loss += loss
|
1297 |
+
|
1298 |
+
# --------------------------
|
1299 |
+
# ----- Context Loss -----
|
1300 |
+
# --------------------------
|
1301 |
+
if lambdas.G.p.context != 0:
|
1302 |
+
loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
|
1303 |
+
loss *= lambdas.G.p.context
|
1304 |
+
self.logger.losses.gen.p.context = loss.item()
|
1305 |
+
step_loss += loss
|
1306 |
+
|
1307 |
+
# ---------------------------------
|
1308 |
+
# ----- Reconstruction Loss -----
|
1309 |
+
# ---------------------------------
|
1310 |
+
if lambdas.G.p.reconstruction != 0:
|
1311 |
+
loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
|
1312 |
+
loss *= lambdas.G.p.reconstruction
|
1313 |
+
self.logger.losses.gen.p.reconstruction = loss.item()
|
1314 |
+
step_loss += loss
|
1315 |
+
|
1316 |
+
# -------------------------------------
|
1317 |
+
# ----- Local & Global GAN Loss -----
|
1318 |
+
# -------------------------------------
|
1319 |
+
if self.opts.gen.p.diff_aug.use:
|
1320 |
+
fake_flooded = self.diff_transforms(fake_flooded)
|
1321 |
+
x = self.diff_transforms(x)
|
1322 |
+
|
1323 |
+
if self.opts.dis.p.use_local_discriminator:
|
1324 |
+
fake_d_global = self.D["p"]["global"](fake_flooded)
|
1325 |
+
fake_d_local = self.D["p"]["local"](fake_flooded * m)
|
1326 |
+
|
1327 |
+
real_d_global = self.D["p"]["global"](x)
|
1328 |
+
|
1329 |
+
# Note: discriminator returns [out_1,...,out_num_D] outputs
|
1330 |
+
# Each out_i is a list [feat1, feat2, ..., pred_i]
|
1331 |
+
|
1332 |
+
self.logger.losses.gen.p.gan = 0
|
1333 |
+
|
1334 |
+
loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
|
1335 |
+
loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
|
1336 |
+
loss *= lambdas.G["p"]["gan"]
|
1337 |
+
|
1338 |
+
self.logger.losses.gen.p.gan = loss.item()
|
1339 |
+
|
1340 |
+
step_loss += loss
|
1341 |
+
|
1342 |
+
# -----------------------------------
|
1343 |
+
# ----- Feature Matching Loss -----
|
1344 |
+
# -----------------------------------
|
1345 |
+
# (only on global discriminator)
|
1346 |
+
# Order must be real, fake
|
1347 |
+
if self.opts.dis.p.get_intermediate_features:
|
1348 |
+
loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
|
1349 |
+
loss *= lambdas.G["p"]["featmatch"]
|
1350 |
+
|
1351 |
+
if isinstance(loss, float):
|
1352 |
+
self.logger.losses.gen.p.featmatch = loss
|
1353 |
+
else:
|
1354 |
+
self.logger.losses.gen.p.featmatch = loss.item()
|
1355 |
+
|
1356 |
+
step_loss += loss
|
1357 |
+
|
1358 |
+
# -------------------------------------------
|
1359 |
+
# ----- Single Discriminator GAN Loss -----
|
1360 |
+
# -------------------------------------------
|
1361 |
+
else:
|
1362 |
+
real_cat = torch.cat([m, x], axis=1)
|
1363 |
+
fake_cat = torch.cat([m, fake_flooded], axis=1)
|
1364 |
+
real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
|
1365 |
+
|
1366 |
+
real_fake_d = self.D["p"](real_fake_cat)
|
1367 |
+
real_d, fake_d = divide_pred(real_fake_d)
|
1368 |
+
|
1369 |
+
loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
|
1370 |
+
self.logger.losses.gen.p.gan = loss.item()
|
1371 |
+
step_loss += loss
|
1372 |
+
|
1373 |
+
# -----------------------------------
|
1374 |
+
# ----- Feature Matching Loss -----
|
1375 |
+
# -----------------------------------
|
1376 |
+
if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
|
1377 |
+
loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
|
1378 |
+
loss *= lambdas.G.p.featmatch
|
1379 |
+
|
1380 |
+
if isinstance(loss, float):
|
1381 |
+
self.logger.losses.gen.p.featmatch = loss
|
1382 |
+
else:
|
1383 |
+
self.logger.losses.gen.p.featmatch = loss.item()
|
1384 |
+
|
1385 |
+
step_loss += loss
|
1386 |
+
|
1387 |
+
return step_loss
|
1388 |
+
|
1389 |
+
def masker_d_loss(self, x, z, target, domain, for_="G"):
|
1390 |
+
assert for_ in {"G", "D"}
|
1391 |
+
self.assert_z_matches_x(x, z)
|
1392 |
+
assert x.shape[0] == target.shape[0]
|
1393 |
+
zero_loss = torch.tensor(0.0, device=self.device)
|
1394 |
+
weight = self.opts.train.lambdas.G.d.main
|
1395 |
+
|
1396 |
+
prediction, z_depth = self.G.decoders["d"](z)
|
1397 |
+
|
1398 |
+
if self.opts.gen.d.classify.enable:
|
1399 |
+
target.squeeze_(1)
|
1400 |
+
|
1401 |
+
full_loss = self.losses["G"]["tasks"]["d"](prediction, target)
|
1402 |
+
full_loss *= weight
|
1403 |
+
|
1404 |
+
if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
|
1405 |
+
return zero_loss, prediction, z_depth
|
1406 |
+
|
1407 |
+
return full_loss, prediction, z_depth
|
1408 |
+
|
1409 |
+
def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
|
1410 |
+
assert for_ in {"G", "D"}
|
1411 |
+
assert domain in {"r", "s"}
|
1412 |
+
self.assert_z_matches_x(x, z)
|
1413 |
+
assert x.shape[0] == target.shape[0] if target is not None else True
|
1414 |
+
full_loss = torch.tensor(0.0, device=self.device)
|
1415 |
+
softmax_preds = None
|
1416 |
+
# --------------------------
|
1417 |
+
# ----- Segmentation -----
|
1418 |
+
# --------------------------
|
1419 |
+
pred = None
|
1420 |
+
if for_ == "G" or self.opts.gen.s.use_advent:
|
1421 |
+
pred = self.G.decoders["s"](z, z_depth)
|
1422 |
+
|
1423 |
+
# Supervised segmentation loss: crossent for sim domain,
|
1424 |
+
# crossent_pseudo for real ; loss is crossent in any case
|
1425 |
+
if for_ == "G":
|
1426 |
+
if domain == "s" or "s" in self.pseudo_training_tasks:
|
1427 |
+
if domain == "s":
|
1428 |
+
logger = self.logger.losses.gen.task["s"]["crossent"]
|
1429 |
+
weight = self.opts.train.lambdas.G["s"]["crossent"]
|
1430 |
+
else:
|
1431 |
+
logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
|
1432 |
+
weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]
|
1433 |
+
|
1434 |
+
if weight != 0:
|
1435 |
+
# Cross-Entropy loss
|
1436 |
+
loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
|
1437 |
+
loss = loss_func(pred, target.squeeze(1))
|
1438 |
+
loss *= weight
|
1439 |
+
full_loss += loss
|
1440 |
+
logger[domain] = loss.item()
|
1441 |
+
|
1442 |
+
if domain == "r":
|
1443 |
+
weight = self.opts.train.lambdas.G["s"]["minent"]
|
1444 |
+
if self.opts.gen.s.use_minent and weight != 0:
|
1445 |
+
softmax_preds = softmax(pred, dim=1)
|
1446 |
+
# Entropy minimization loss
|
1447 |
+
loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
|
1448 |
+
loss *= weight
|
1449 |
+
full_loss += loss
|
1450 |
+
|
1451 |
+
self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()
|
1452 |
+
|
1453 |
+
# Fool ADVENT discriminator
|
1454 |
+
if self.opts.gen.s.use_advent:
|
1455 |
+
if self.opts.gen.s.use_dada and depth_preds is not None:
|
1456 |
+
depth_preds = depth_preds.detach()
|
1457 |
+
else:
|
1458 |
+
depth_preds = None
|
1459 |
+
|
1460 |
+
if for_ == "D":
|
1461 |
+
domain_label = domain
|
1462 |
+
logger = {}
|
1463 |
+
loss_func = self.losses["D"]["advent"]
|
1464 |
+
pred = pred.detach()
|
1465 |
+
weight = self.opts.train.lambdas.advent.adv_main
|
1466 |
+
else:
|
1467 |
+
domain_label = "s"
|
1468 |
+
logger = self.logger.losses.gen.task["s"]["advent"]
|
1469 |
+
loss_func = self.losses["G"]["tasks"]["s"]["advent"]
|
1470 |
+
weight = self.opts.train.lambdas.G["s"]["advent"]
|
1471 |
+
|
1472 |
+
if (for_ == "D" or domain == "r") and weight != 0:
|
1473 |
+
if softmax_preds is None:
|
1474 |
+
softmax_preds = softmax(pred, dim=1)
|
1475 |
+
loss = loss_func(
|
1476 |
+
softmax_preds,
|
1477 |
+
self.domain_labels[domain_label],
|
1478 |
+
self.D["s"]["Advent"],
|
1479 |
+
depth_preds,
|
1480 |
+
)
|
1481 |
+
loss *= weight
|
1482 |
+
full_loss += loss
|
1483 |
+
logger[domain] = loss.item()
|
1484 |
+
|
1485 |
+
if for_ == "D":
|
1486 |
+
# WGAN: clipping or GP
|
1487 |
+
if self.opts.dis.s.gan_type == "GAN" or "WGAN_norm":
|
1488 |
+
pass
|
1489 |
+
elif self.opts.dis.s.gan_type == "WGAN":
|
1490 |
+
for p in self.D["s"]["Advent"].parameters():
|
1491 |
+
p.data.clamp_(
|
1492 |
+
self.opts.dis.s.wgan_clamp_lower,
|
1493 |
+
self.opts.dis.s.wgan_clamp_upper,
|
1494 |
+
)
|
1495 |
+
elif self.opts.dis.s.gan_type == "WGAN_gp":
|
1496 |
+
prob_need_grad = autograd.Variable(pred, requires_grad=True)
|
1497 |
+
d_out = self.D["s"]["Advent"](prob_need_grad)
|
1498 |
+
gp = get_WGAN_gradient(prob_need_grad, d_out)
|
1499 |
+
gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
|
1500 |
+
full_loss += gp_loss
|
1501 |
+
else:
|
1502 |
+
raise NotImplementedError
|
1503 |
+
|
1504 |
+
return full_loss, pred
|
1505 |
+
|
1506 |
+
def masker_m_loss(
|
1507 |
+
self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
|
1508 |
+
):
|
1509 |
+
assert for_ in {"G", "D"}
|
1510 |
+
assert domain in {"r", "s"}
|
1511 |
+
self.assert_z_matches_x(x, z)
|
1512 |
+
assert x.shape[0] == target.shape[0] if target is not None else True
|
1513 |
+
full_loss = torch.tensor(0.0, device=self.device)
|
1514 |
+
|
1515 |
+
pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
|
1516 |
+
pred_prob = sigmoid(pred_logits)
|
1517 |
+
pred_prob_complementary = 1 - pred_prob
|
1518 |
+
prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)
|
1519 |
+
|
1520 |
+
if for_ == "G":
|
1521 |
+
# TV loss
|
1522 |
+
weight = self.opts.train.lambdas.G.m.tv
|
1523 |
+
if weight != 0:
|
1524 |
+
loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
|
1525 |
+
loss *= weight
|
1526 |
+
full_loss += loss
|
1527 |
+
|
1528 |
+
self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()
|
1529 |
+
|
1530 |
+
weight = self.opts.train.lambdas.G.m.bce
|
1531 |
+
if domain == "s" and weight != 0:
|
1532 |
+
# CrossEnt Loss
|
1533 |
+
loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
|
1534 |
+
loss *= weight
|
1535 |
+
full_loss += loss
|
1536 |
+
self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()
|
1537 |
+
|
1538 |
+
if domain == "r":
|
1539 |
+
|
1540 |
+
weight = self.opts.train.lambdas.G["m"]["gi"]
|
1541 |
+
if self.opts.gen.m.use_ground_intersection and weight != 0:
|
1542 |
+
# GroundIntersection loss
|
1543 |
+
loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
|
1544 |
+
loss *= weight
|
1545 |
+
full_loss += loss
|
1546 |
+
self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()
|
1547 |
+
|
1548 |
+
weight = self.opts.train.lambdas.G.m.pl4m
|
1549 |
+
if self.use_pl4m and weight != 0:
|
1550 |
+
# Painter loss
|
1551 |
+
pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
|
1552 |
+
pl4m_loss *= weight
|
1553 |
+
full_loss += pl4m_loss
|
1554 |
+
self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()
|
1555 |
+
|
1556 |
+
weight = self.opts.train.lambdas.advent.ent_main
|
1557 |
+
if self.opts.gen.m.use_minent and weight != 0:
|
1558 |
+
# MinEnt loss
|
1559 |
+
loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
|
1560 |
+
loss *= weight
|
1561 |
+
full_loss += loss
|
1562 |
+
self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()
|
1563 |
+
|
1564 |
+
if self.opts.gen.m.use_advent:
|
1565 |
+
# AdvEnt loss
|
1566 |
+
if self.opts.gen.m.use_dada and depth_preds is not None:
|
1567 |
+
depth_preds = depth_preds.detach()
|
1568 |
+
depth_preds = torch.nn.functional.interpolate(
|
1569 |
+
depth_preds, size=x.shape[-2:], mode="nearest"
|
1570 |
+
)
|
1571 |
+
else:
|
1572 |
+
depth_preds = None
|
1573 |
+
|
1574 |
+
if for_ == "D":
|
1575 |
+
domain_label = domain
|
1576 |
+
logger = {}
|
1577 |
+
loss_func = self.losses["D"]["advent"]
|
1578 |
+
prob = prob.detach()
|
1579 |
+
weight = self.opts.train.lambdas.advent.adv_main
|
1580 |
+
else:
|
1581 |
+
domain_label = "s"
|
1582 |
+
logger = self.logger.losses.gen.task["m"]["advent"]
|
1583 |
+
loss_func = self.losses["G"]["tasks"]["m"]["advent"]
|
1584 |
+
weight = self.opts.train.lambdas.advent.adv_main
|
1585 |
+
|
1586 |
+
if (for_ == "D" or domain == "r") and weight != 0:
|
1587 |
+
loss = loss_func(
|
1588 |
+
prob.to(self.device),
|
1589 |
+
self.domain_labels[domain_label],
|
1590 |
+
self.D["m"]["Advent"],
|
1591 |
+
depth_preds,
|
1592 |
+
)
|
1593 |
+
loss *= weight
|
1594 |
+
full_loss += loss
|
1595 |
+
logger[domain] = loss.item()
|
1596 |
+
|
1597 |
+
if for_ == "D":
|
1598 |
+
# WGAN: clipping or GP
|
1599 |
+
if self.opts.dis.m.gan_type == "GAN" or "WGAN_norm":
|
1600 |
+
pass
|
1601 |
+
elif self.opts.dis.m.gan_type == "WGAN":
|
1602 |
+
for p in self.D["s"]["Advent"].parameters():
|
1603 |
+
p.data.clamp_(
|
1604 |
+
self.opts.dis.m.wgan_clamp_lower,
|
1605 |
+
self.opts.dis.m.wgan_clamp_upper,
|
1606 |
+
)
|
1607 |
+
elif self.opts.dis.m.gan_type == "WGAN_gp":
|
1608 |
+
prob_need_grad = autograd.Variable(prob, requires_grad=True)
|
1609 |
+
d_out = self.D["s"]["Advent"](prob_need_grad)
|
1610 |
+
gp = get_WGAN_gradient(prob_need_grad, d_out)
|
1611 |
+
gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
|
1612 |
+
full_loss += gp_loss
|
1613 |
+
else:
|
1614 |
+
raise NotImplementedError
|
1615 |
+
|
1616 |
+
return full_loss, prob
|
1617 |
+
|
1618 |
+
def painter_loss_for_masker(self, x, m):
|
1619 |
+
# pl4m loss
|
1620 |
+
# painter should not be updated
|
1621 |
+
for param in self.G.painter.parameters():
|
1622 |
+
param.requires_grad = False
|
1623 |
+
# TODO for param in self.D.painter.parameters():
|
1624 |
+
# param.requires_grad = False
|
1625 |
+
|
1626 |
+
fake_flooded = self.G.paint(m, x)
|
1627 |
+
|
1628 |
+
if self.opts.dis.p.use_local_discriminator:
|
1629 |
+
fake_d_global = self.D["p"]["global"](fake_flooded)
|
1630 |
+
fake_d_local = self.D["p"]["local"](fake_flooded * m)
|
1631 |
+
|
1632 |
+
# Note: discriminator returns [out_1,...,out_num_D] outputs
|
1633 |
+
# Each out_i is a list [feat1, feat2, ..., pred_i]
|
1634 |
+
|
1635 |
+
pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
|
1636 |
+
pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
|
1637 |
+
else:
|
1638 |
+
real_cat = torch.cat([m, x], axis=1)
|
1639 |
+
fake_cat = torch.cat([m, fake_flooded], axis=1)
|
1640 |
+
real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
|
1641 |
+
|
1642 |
+
real_fake_d = self.D["p"](real_fake_cat)
|
1643 |
+
_, fake_d = divide_pred(real_fake_d)
|
1644 |
+
|
1645 |
+
pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
|
1646 |
+
|
1647 |
+
if "p" in self.opts.tasks:
|
1648 |
+
for param in self.G.painter.parameters():
|
1649 |
+
param.requires_grad = True
|
1650 |
+
|
1651 |
+
return pl4m_loss
|
1652 |
+
|
1653 |
+
@torch.no_grad()
|
1654 |
+
def run_evaluation(self, verbose=0):
|
1655 |
+
print("******************* Running Evaluation ***********************")
|
1656 |
+
start_time = time()
|
1657 |
+
self.eval_mode()
|
1658 |
+
val_logger = None
|
1659 |
+
nb_of_batches = None
|
1660 |
+
for i, multi_batch_tuple in enumerate(self.val_loaders):
|
1661 |
+
# create a dictionnary (domain => batch) from tuple
|
1662 |
+
# (batch_domain_0, ..., batch_domain_i)
|
1663 |
+
# and send it to self.device
|
1664 |
+
nb_of_batches = i + 1
|
1665 |
+
multi_domain_batch = {
|
1666 |
+
batch["domain"][0]: self.batch_to_device(batch)
|
1667 |
+
for batch in multi_batch_tuple
|
1668 |
+
}
|
1669 |
+
self.get_G_loss(multi_domain_batch, verbose)
|
1670 |
+
|
1671 |
+
if val_logger is None:
|
1672 |
+
val_logger = deepcopy(self.logger.losses.generator)
|
1673 |
+
else:
|
1674 |
+
val_logger = sum_dict(val_logger, self.logger.losses.generator)
|
1675 |
+
|
1676 |
+
val_logger = div_dict(val_logger, nb_of_batches)
|
1677 |
+
self.logger.losses.generator = val_logger
|
1678 |
+
self.logger.log_losses(model_to_update="G", mode="val")
|
1679 |
+
|
1680 |
+
for d in self.opts.domains:
|
1681 |
+
self.logger.log_comet_images("train", d)
|
1682 |
+
self.logger.log_comet_images("val", d)
|
1683 |
+
|
1684 |
+
if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
|
1685 |
+
self.logger.log_comet_combined_images("train", "r")
|
1686 |
+
self.logger.log_comet_combined_images("val", "r")
|
1687 |
+
|
1688 |
+
if self.exp is not None:
|
1689 |
+
print()
|
1690 |
+
|
1691 |
+
if "m" in self.opts.tasks or "s" in self.opts.tasks:
|
1692 |
+
self.eval_images("val", "r")
|
1693 |
+
self.eval_images("val", "s")
|
1694 |
+
|
1695 |
+
if "p" in self.opts.tasks and not self.kitti_pretrain:
|
1696 |
+
val_fid = compute_val_fid(self)
|
1697 |
+
if self.exp is not None:
|
1698 |
+
self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
|
1699 |
+
else:
|
1700 |
+
print("Validation FID Score", val_fid)
|
1701 |
+
|
1702 |
+
self.train_mode()
|
1703 |
+
timing = int(time() - start_time)
|
1704 |
+
print("****************** Done in {}s *********************".format(timing))
|
1705 |
+
|
1706 |
+
def eval_images(self, mode, domain):
|
1707 |
+
if domain == "s" and self.kitti_pretrain:
|
1708 |
+
domain = "kitti"
|
1709 |
+
if domain == "rf" or domain not in self.display_images[mode]:
|
1710 |
+
return
|
1711 |
+
|
1712 |
+
metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
|
1713 |
+
metric_avg_scores = {"m": {}}
|
1714 |
+
if "s" in self.opts.tasks:
|
1715 |
+
metric_avg_scores["s"] = {}
|
1716 |
+
if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
|
1717 |
+
metric_avg_scores["d"] = {}
|
1718 |
+
|
1719 |
+
for key in metric_funcs:
|
1720 |
+
for task in metric_avg_scores:
|
1721 |
+
metric_avg_scores[task][key] = []
|
1722 |
+
|
1723 |
+
for im_set in self.display_images[mode][domain]:
|
1724 |
+
x = im_set["data"]["x"].unsqueeze(0).to(self.device)
|
1725 |
+
z = self.G.encode(x)
|
1726 |
+
|
1727 |
+
s_pred = d_pred = z_depth = None
|
1728 |
+
|
1729 |
+
if "d" in metric_avg_scores:
|
1730 |
+
d_pred, z_depth = self.G.decoders["d"](z)
|
1731 |
+
d_pred = d_pred.detach().cpu()
|
1732 |
+
|
1733 |
+
if domain == "s":
|
1734 |
+
d = im_set["data"]["d"].unsqueeze(0).detach()
|
1735 |
+
|
1736 |
+
for metric in metric_funcs:
|
1737 |
+
metric_score = metric_funcs[metric](d_pred, d)
|
1738 |
+
metric_avg_scores["d"][metric].append(metric_score)
|
1739 |
+
|
1740 |
+
if "s" in metric_avg_scores:
|
1741 |
+
if z_depth is None:
|
1742 |
+
if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
|
1743 |
+
_, z_depth = self.G.decoders["d"](z)
|
1744 |
+
s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
|
1745 |
+
s = im_set["data"]["s"].unsqueeze(0).detach()
|
1746 |
+
|
1747 |
+
for metric in metric_funcs:
|
1748 |
+
metric_score = metric_funcs[metric](s_pred, s)
|
1749 |
+
metric_avg_scores["s"][metric].append(metric_score)
|
1750 |
+
|
1751 |
+
if "m" in self.opts:
|
1752 |
+
cond = None
|
1753 |
+
if s_pred is not None and d_pred is not None:
|
1754 |
+
cond = self.G.make_m_cond(d_pred, s_pred, x)
|
1755 |
+
if z_depth is None:
|
1756 |
+
if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
|
1757 |
+
_, z_depth = self.G.decoders["d"](z)
|
1758 |
+
|
1759 |
+
pred_mask = (
|
1760 |
+
(self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
|
1761 |
+
)
|
1762 |
+
pred_mask = (pred_mask > 0.5).to(torch.float32)
|
1763 |
+
pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)
|
1764 |
+
|
1765 |
+
m = im_set["data"]["m"].unsqueeze(0).detach()
|
1766 |
+
|
1767 |
+
for metric in metric_funcs:
|
1768 |
+
if metric != "mIOU":
|
1769 |
+
metric_score = metric_funcs[metric](pred_mask, m)
|
1770 |
+
else:
|
1771 |
+
metric_score = metric_funcs[metric](pred_prob, m)
|
1772 |
+
|
1773 |
+
metric_avg_scores["m"][metric].append(metric_score)
|
1774 |
+
|
1775 |
+
metric_avg_scores = {
|
1776 |
+
task: {
|
1777 |
+
metric: np.mean(values) if values else float("nan")
|
1778 |
+
for metric, values in met_dict.items()
|
1779 |
+
}
|
1780 |
+
for task, met_dict in metric_avg_scores.items()
|
1781 |
+
}
|
1782 |
+
metric_avg_scores = {
|
1783 |
+
task: {
|
1784 |
+
metric: value if not np.isnan(value) else -1
|
1785 |
+
for metric, value in met_dict.items()
|
1786 |
+
}
|
1787 |
+
for task, met_dict in metric_avg_scores.items()
|
1788 |
+
}
|
1789 |
+
if self.exp is not None:
|
1790 |
+
self.exp.log_metrics(
|
1791 |
+
flatten_opts(metric_avg_scores),
|
1792 |
+
prefix=f"metrics_{mode}_{domain}",
|
1793 |
+
step=self.logger.global_step,
|
1794 |
+
)
|
1795 |
+
else:
|
1796 |
+
print(f"metrics_{mode}_{domain}")
|
1797 |
+
print(flatten_opts(metric_avg_scores))
|
1798 |
+
|
1799 |
+
return 0
|
1800 |
+
|
1801 |
+
def functional_test_mode(self):
|
1802 |
+
import atexit
|
1803 |
+
|
1804 |
+
self.opts.output_path = (
|
1805 |
+
Path("~").expanduser() / "climategan" / "functional_tests"
|
1806 |
+
)
|
1807 |
+
Path(self.opts.output_path).mkdir(parents=True, exist_ok=True)
|
1808 |
+
with open(Path(self.opts.output_path) / "is_functional.test", "w") as f:
|
1809 |
+
f.write("trainer functional test - delete this dir")
|
1810 |
+
|
1811 |
+
if self.exp is not None:
|
1812 |
+
self.exp.log_parameter("is_functional_test", True)
|
1813 |
+
atexit.register(self.del_output_path)
|
1814 |
+
|
1815 |
+
def del_output_path(self, force=False):
|
1816 |
+
import shutil
|
1817 |
+
|
1818 |
+
if not Path(self.opts.output_path).exists():
|
1819 |
+
return
|
1820 |
+
|
1821 |
+
if (Path(self.opts.output_path) / "is_functional.test").exists() or force:
|
1822 |
+
shutil.rmtree(self.opts.output_path)
|
1823 |
+
|
1824 |
+
def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
|
1825 |
+
"""
|
1826 |
+
Transforms input tensor given wildfires event
|
1827 |
+
Args:
|
1828 |
+
x (torch.Tensor): Input tensor
|
1829 |
+
seg_preds (torch.Tensor): Semantic segmentation
|
1830 |
+
predictions for input tensor
|
1831 |
+
z (torch.Tensor): Latent vector of encoded "x".
|
1832 |
+
Can be None if seg_preds is given.
|
1833 |
+
Returns:
|
1834 |
+
torch.Tensor: Wildfire version of input tensor
|
1835 |
+
"""
|
1836 |
+
|
1837 |
+
if seg_preds is None:
|
1838 |
+
if z is None:
|
1839 |
+
z = self.G.encode(x)
|
1840 |
+
seg_preds = self.G.decoders["s"](z, z_depth)
|
1841 |
+
|
1842 |
+
return add_fire(x, seg_preds, self.opts.events.fire)
|
1843 |
+
|
1844 |
+
def compute_flood(
|
1845 |
+
self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
|
1846 |
+
):
|
1847 |
+
"""
|
1848 |
+
Applies a flood (mask + paint) to an input image, with optionally
|
1849 |
+
pre-computed masker z or mask
|
1850 |
+
|
1851 |
+
Args:
|
1852 |
+
x (torch.Tensor): B x C x H x W -1:1 input image
|
1853 |
+
z (torch.Tensor, optional): B x C x H x W Masker latent vector.
|
1854 |
+
Defaults to None.
|
1855 |
+
m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None.
|
1856 |
+
bin_value (float, optional): Mask binarization value.
|
1857 |
+
Set to -1 to use smooth masks (no binarization)
|
1858 |
+
|
1859 |
+
Returns:
|
1860 |
+
torch.Tensor: B x 3 x H x W -1:1 flooded image
|
1861 |
+
"""
|
1862 |
+
|
1863 |
+
if m is None:
|
1864 |
+
if z is None:
|
1865 |
+
z = self.G.encode(x)
|
1866 |
+
if "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None:
|
1867 |
+
_, z_depth = self.G.decoders["d"](z)
|
1868 |
+
m = self.G.mask(x=x, z=z, z_depth=z_depth)
|
1869 |
+
|
1870 |
+
if bin_value >= 0:
|
1871 |
+
m = (m > bin_value).to(m.dtype)
|
1872 |
+
|
1873 |
+
if cloudy:
|
1874 |
+
assert s is not None
|
1875 |
+
return self.G.paint_cloudy(m, x, s)
|
1876 |
+
|
1877 |
+
return self.G.paint(m, x)
|
1878 |
+
|
1879 |
+
def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
|
1880 |
+
# implementation from the paper:
|
1881 |
+
# HazeRD: An outdoor scene dataset and benchmark for single image dehazing
|
1882 |
+
sky_mask = None
|
1883 |
+
if d is None or (use_sky_seg and s is None):
|
1884 |
+
if z is None:
|
1885 |
+
z = self.G.encode(x)
|
1886 |
+
if d is None:
|
1887 |
+
d, _ = self.G.decoders["d"](z)
|
1888 |
+
if use_sky_seg and s is None:
|
1889 |
+
if "s" not in self.opts.tasks:
|
1890 |
+
raise ValueError(
|
1891 |
+
"Cannot have "
|
1892 |
+
+ "(use_sky_seg is True and s is None and 's' not in tasks)"
|
1893 |
+
)
|
1894 |
+
s = self.G.decoders["s"](z)
|
1895 |
+
# TODO: s to sky mask
|
1896 |
+
# TODO: interpolate to d's size
|
1897 |
+
|
1898 |
+
params = self.opts.events.smog
|
1899 |
+
|
1900 |
+
airlight = params.airlight * torch.ones(3)
|
1901 |
+
airlight = airlight.view(1, -1, 1, 1).to(self.device)
|
1902 |
+
|
1903 |
+
irradiance = srgb2lrgb(x)
|
1904 |
+
|
1905 |
+
beta = torch.tensor([params.beta / params.vr] * 3)
|
1906 |
+
beta = beta.view(1, -1, 1, 1).to(self.device)
|
1907 |
+
|
1908 |
+
d = normalize(d, mini=0.3, maxi=1.0)
|
1909 |
+
d = 1.0 / d
|
1910 |
+
d = normalize(d, mini=0.1, maxi=1)
|
1911 |
+
|
1912 |
+
if sky_mask is not None:
|
1913 |
+
d[sky_mask] = 1
|
1914 |
+
|
1915 |
+
d = torch.nn.functional.interpolate(
|
1916 |
+
d, size=x.shape[-2:], mode="bilinear", align_corners=True
|
1917 |
+
)
|
1918 |
+
|
1919 |
+
d = d.repeat(1, 3, 1, 1)
|
1920 |
+
|
1921 |
+
transmission = torch.exp(d * -beta)
|
1922 |
+
|
1923 |
+
smogged = transmission * irradiance + (1 - transmission) * airlight
|
1924 |
+
|
1925 |
+
smogged = lrgb2srgb(smogged)
|
1926 |
+
|
1927 |
+
# add yellow filter
|
1928 |
+
alpha = params.alpha / 255
|
1929 |
+
yellow_mask = torch.Tensor([params.yellow_color]) / 255
|
1930 |
+
yellow_filter = (
|
1931 |
+
yellow_mask.unsqueeze(2)
|
1932 |
+
.unsqueeze(2)
|
1933 |
+
.repeat(1, 1, smogged.shape[-2], smogged.shape[-1])
|
1934 |
+
.to(self.device)
|
1935 |
+
)
|
1936 |
+
|
1937 |
+
smogged = smogged * (1 - alpha) + yellow_filter * alpha
|
1938 |
+
|
1939 |
+
return smogged
|
climategan/transforms.py
ADDED
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Data transforms for the loaders
|
2 |
+
"""
|
3 |
+
import random
|
4 |
+
import traceback
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from skimage.color import rgba2rgb
|
11 |
+
from skimage.io import imread
|
12 |
+
from torchvision import transforms as trsfs
|
13 |
+
from torchvision.transforms.functional import (
|
14 |
+
adjust_brightness,
|
15 |
+
adjust_contrast,
|
16 |
+
adjust_saturation,
|
17 |
+
)
|
18 |
+
|
19 |
+
from climategan.tutils import normalize
|
20 |
+
|
21 |
+
|
22 |
+
def interpolation(task):
|
23 |
+
if task in ["d", "m", "s"]:
|
24 |
+
return {"mode": "nearest"}
|
25 |
+
else:
|
26 |
+
return {"mode": "bilinear", "align_corners": True}
|
27 |
+
|
28 |
+
|
29 |
+
class Resize:
|
30 |
+
def __init__(self, target_size, keep_aspect_ratio=False):
|
31 |
+
"""
|
32 |
+
Resize transform. Target_size can be an int or a tuple of ints,
|
33 |
+
depending on whether both height and width should have the same
|
34 |
+
final size or not.
|
35 |
+
|
36 |
+
If keep_aspect_ratio is specified then target_size must be an int:
|
37 |
+
the smallest dimension of x will be set to target_size and the largest
|
38 |
+
dimension will be computed to the closest int keeping the original
|
39 |
+
aspect ratio. e.g.
|
40 |
+
>>> x = torch.rand(1, 3, 1200, 1800)
|
41 |
+
>>> m = torch.rand(1, 1, 600, 600)
|
42 |
+
>>> d = {"x": x, "m": m}
|
43 |
+
>>> {k: v.shape for k, v in Resize(640, True)(d).items()}
|
44 |
+
{"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)}
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
Args:
|
49 |
+
target_size (int | tuple(int)): New size for the tensor
|
50 |
+
keep_aspect_ratio (bool, optional): Whether or not to keep aspect ratio
|
51 |
+
when resizing. Requires target_size to be an int. If keeping aspect
|
52 |
+
ratio, smallest dim will be set to target_size. Defaults to False.
|
53 |
+
"""
|
54 |
+
if isinstance(target_size, (int, tuple, list)):
|
55 |
+
if not isinstance(target_size, int) and not keep_aspect_ratio:
|
56 |
+
assert len(target_size) == 2
|
57 |
+
self.h, self.w = target_size
|
58 |
+
else:
|
59 |
+
if keep_aspect_ratio:
|
60 |
+
assert isinstance(target_size, int)
|
61 |
+
self.h = self.w = target_size
|
62 |
+
|
63 |
+
self.default_h = int(self.h)
|
64 |
+
self.default_w = int(self.w)
|
65 |
+
self.sizes = {}
|
66 |
+
elif isinstance(target_size, dict):
|
67 |
+
assert (
|
68 |
+
not keep_aspect_ratio
|
69 |
+
), "dict target_size not compatible with keep_aspect_ratio"
|
70 |
+
|
71 |
+
self.sizes = {
|
72 |
+
k: {"h": v, "w": v} for k, v in target_size.items() if k != "default"
|
73 |
+
}
|
74 |
+
self.default_h = int(target_size["default"])
|
75 |
+
self.default_w = int(target_size["default"])
|
76 |
+
|
77 |
+
self.keep_aspect_ratio = keep_aspect_ratio
|
78 |
+
|
79 |
+
def compute_new_default_size(self, tensor):
|
80 |
+
"""
|
81 |
+
compute the new size for a tensor depending on target size
|
82 |
+
and keep_aspect_rato
|
83 |
+
|
84 |
+
Args:
|
85 |
+
tensor (torch.Tensor): 4D tensor N x C x H x W.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
tuple(int): (new_height, new_width)
|
89 |
+
"""
|
90 |
+
if self.keep_aspect_ratio:
|
91 |
+
h, w = tensor.shape[-2:]
|
92 |
+
if h < w:
|
93 |
+
return (self.h, int(self.default_h * w / h))
|
94 |
+
else:
|
95 |
+
return (int(self.default_h * h / w), self.default_w)
|
96 |
+
return (self.default_h, self.default_w)
|
97 |
+
|
98 |
+
def compute_new_size_for_task(self, task):
|
99 |
+
assert (
|
100 |
+
not self.keep_aspect_ratio
|
101 |
+
), "compute_new_size_for_task is not compatible with keep aspect ratio"
|
102 |
+
|
103 |
+
if task not in self.sizes:
|
104 |
+
return (self.default_h, self.default_w)
|
105 |
+
|
106 |
+
return (self.sizes[task]["h"], self.sizes[task]["w"])
|
107 |
+
|
108 |
+
def __call__(self, data):
|
109 |
+
"""
|
110 |
+
Resize a dict of tensors to the "x" key's new_size
|
111 |
+
|
112 |
+
Args:
|
113 |
+
data (dict[str:torch.Tensor]): The data dict to transform
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
dict[str: torch.Tensor]: dict with all tensors resized to the
|
117 |
+
new size of the data["x"] tensor
|
118 |
+
"""
|
119 |
+
task = tensor = new_size = None
|
120 |
+
try:
|
121 |
+
if not self.sizes:
|
122 |
+
d = {}
|
123 |
+
new_size = self.compute_new_default_size(
|
124 |
+
data["x"] if "x" in data else list(data.values())[0]
|
125 |
+
)
|
126 |
+
for task, tensor in data.items():
|
127 |
+
d[task] = F.interpolate(
|
128 |
+
tensor, size=new_size, **interpolation(task)
|
129 |
+
)
|
130 |
+
return d
|
131 |
+
|
132 |
+
d = {}
|
133 |
+
for task, tensor in data.items():
|
134 |
+
new_size = self.compute_new_size_for_task(task)
|
135 |
+
d[task] = F.interpolate(tensor, size=new_size, **interpolation(task))
|
136 |
+
return d
|
137 |
+
|
138 |
+
except Exception as e:
|
139 |
+
tb = traceback.format_exc()
|
140 |
+
print("Debug: task, shape, interpolation, h, w, new_size")
|
141 |
+
print(task)
|
142 |
+
print(tensor.shape)
|
143 |
+
print(interpolation(task))
|
144 |
+
print(self.h, self.w)
|
145 |
+
print(new_size)
|
146 |
+
print(tb)
|
147 |
+
raise Exception(e)
|
148 |
+
|
149 |
+
|
150 |
+
class RandomCrop:
|
151 |
+
def __init__(self, size, center=False):
|
152 |
+
assert isinstance(size, (int, tuple, list))
|
153 |
+
if not isinstance(size, int):
|
154 |
+
assert len(size) == 2
|
155 |
+
self.h, self.w = size
|
156 |
+
else:
|
157 |
+
self.h = self.w = size
|
158 |
+
|
159 |
+
self.h = int(self.h)
|
160 |
+
self.w = int(self.w)
|
161 |
+
self.center = center
|
162 |
+
|
163 |
+
def __call__(self, data):
|
164 |
+
H, W = (
|
165 |
+
data["x"].size()[-2:] if "x" in data else list(data.values())[0].size()[-2:]
|
166 |
+
)
|
167 |
+
|
168 |
+
if not self.center:
|
169 |
+
top = np.random.randint(0, H - self.h)
|
170 |
+
left = np.random.randint(0, W - self.w)
|
171 |
+
else:
|
172 |
+
top = (H - self.h) // 2
|
173 |
+
left = (W - self.w) // 2
|
174 |
+
|
175 |
+
return {
|
176 |
+
task: tensor[:, :, top : top + self.h, left : left + self.w]
|
177 |
+
for task, tensor in data.items()
|
178 |
+
}
|
179 |
+
|
180 |
+
|
181 |
+
class RandomHorizontalFlip:
|
182 |
+
def __init__(self, p=0.5):
|
183 |
+
# self.flip = TF.hflip
|
184 |
+
self.p = p
|
185 |
+
|
186 |
+
def __call__(self, data):
|
187 |
+
if np.random.rand() > self.p:
|
188 |
+
return data
|
189 |
+
return {task: torch.flip(tensor, [3]) for task, tensor in data.items()}
|
190 |
+
|
191 |
+
|
192 |
+
class ToTensor:
|
193 |
+
def __init__(self):
|
194 |
+
self.ImagetoTensor = trsfs.ToTensor()
|
195 |
+
self.MaptoTensor = self.ImagetoTensor
|
196 |
+
|
197 |
+
def __call__(self, data):
|
198 |
+
new_data = {}
|
199 |
+
for task, im in data.items():
|
200 |
+
if task in {"x", "a"}:
|
201 |
+
new_data[task] = self.ImagetoTensor(im)
|
202 |
+
elif task in {"m"}:
|
203 |
+
new_data[task] = self.MaptoTensor(im)
|
204 |
+
elif task == "s":
|
205 |
+
new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(
|
206 |
+
torch.int64
|
207 |
+
)
|
208 |
+
elif task == "d":
|
209 |
+
new_data = im
|
210 |
+
|
211 |
+
return new_data
|
212 |
+
|
213 |
+
|
214 |
+
class Normalize:
|
215 |
+
def __init__(self, opts):
|
216 |
+
if opts.data.normalization == "HRNet":
|
217 |
+
self.normImage = trsfs.Normalize(
|
218 |
+
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
219 |
+
)
|
220 |
+
else:
|
221 |
+
self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
222 |
+
self.normDepth = lambda x: x
|
223 |
+
self.normMask = lambda x: x
|
224 |
+
self.normSeg = lambda x: x
|
225 |
+
|
226 |
+
self.normalize = {
|
227 |
+
"x": self.normImage,
|
228 |
+
"s": self.normSeg,
|
229 |
+
"d": self.normDepth,
|
230 |
+
"m": self.normMask,
|
231 |
+
}
|
232 |
+
|
233 |
+
def __call__(self, data):
|
234 |
+
return {
|
235 |
+
task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0))
|
236 |
+
for task, tensor in data.items()
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
class RandBrightness: # Input need to be between -1 and 1
|
241 |
+
def __call__(self, data):
|
242 |
+
return {
|
243 |
+
task: rand_brightness(tensor) if task == "x" else tensor
|
244 |
+
for task, tensor in data.items()
|
245 |
+
}
|
246 |
+
|
247 |
+
|
248 |
+
class RandSaturation:
|
249 |
+
def __call__(self, data):
|
250 |
+
return {
|
251 |
+
task: rand_saturation(tensor) if task == "x" else tensor
|
252 |
+
for task, tensor in data.items()
|
253 |
+
}
|
254 |
+
|
255 |
+
|
256 |
+
class RandContrast:
|
257 |
+
def __call__(self, data):
|
258 |
+
return {
|
259 |
+
task: rand_contrast(tensor) if task == "x" else tensor
|
260 |
+
for task, tensor in data.items()
|
261 |
+
}
|
262 |
+
|
263 |
+
|
264 |
+
class BucketizeDepth:
|
265 |
+
def __init__(self, opts, domain):
|
266 |
+
self.domain = domain
|
267 |
+
|
268 |
+
if opts.gen.d.classify.enable and domain in {"s", "kitti"}:
|
269 |
+
self.buckets = torch.linspace(
|
270 |
+
*[
|
271 |
+
opts.gen.d.classify.linspace.min,
|
272 |
+
opts.gen.d.classify.linspace.max,
|
273 |
+
opts.gen.d.classify.linspace.buckets - 1,
|
274 |
+
]
|
275 |
+
)
|
276 |
+
|
277 |
+
self.transforms = {
|
278 |
+
"d": lambda tensor: torch.bucketize(
|
279 |
+
tensor, self.buckets, out_int32=True, right=True
|
280 |
+
)
|
281 |
+
}
|
282 |
+
else:
|
283 |
+
self.transforms = {}
|
284 |
+
|
285 |
+
def __call__(self, data):
|
286 |
+
return {
|
287 |
+
task: self.transforms.get(task, lambda x: x)(tensor)
|
288 |
+
for task, tensor in data.items()
|
289 |
+
}
|
290 |
+
|
291 |
+
|
292 |
+
class PrepareInference:
|
293 |
+
"""
|
294 |
+
Transform which:
|
295 |
+
- transforms a str or an array into a tensor
|
296 |
+
- resizes the image to keep the aspect ratio
|
297 |
+
- crops in the center of the resized image
|
298 |
+
- normalize to 0:1
|
299 |
+
- rescale to -1:1
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True):
|
303 |
+
if enforce_128:
|
304 |
+
if target_size % 2 ** 7 != 0:
|
305 |
+
raise ValueError(
|
306 |
+
f"Received a target_size of {target_size}, which is not a "
|
307 |
+
+ "multiple of 2^7 = 128. Set enforce_128 to False to disable "
|
308 |
+
+ "this error."
|
309 |
+
)
|
310 |
+
self.resize = Resize(target_size, keep_aspect_ratio=True)
|
311 |
+
self.crop = RandomCrop((target_size, target_size), center=True)
|
312 |
+
self.half = half
|
313 |
+
self.is_label = is_label
|
314 |
+
|
315 |
+
def process(self, t):
|
316 |
+
if isinstance(t, (str, Path)):
|
317 |
+
t = imread(str(t))
|
318 |
+
|
319 |
+
if isinstance(t, np.ndarray):
|
320 |
+
if t.shape[-1] == 4:
|
321 |
+
t = rgba2rgb(t)
|
322 |
+
|
323 |
+
t = torch.from_numpy(t)
|
324 |
+
if t.ndim == 3:
|
325 |
+
t = t.permute(2, 0, 1)
|
326 |
+
|
327 |
+
if t.ndim == 3:
|
328 |
+
t = t.unsqueeze(0)
|
329 |
+
elif t.ndim == 2:
|
330 |
+
t = t.unsqueeze(0).unsqueeze(0)
|
331 |
+
|
332 |
+
if not self.is_label:
|
333 |
+
t = t.to(torch.float32)
|
334 |
+
t = normalize(t)
|
335 |
+
t = (t - 0.5) * 2
|
336 |
+
|
337 |
+
t = {"m": t} if self.is_label else {"x": t}
|
338 |
+
t = self.resize(t)
|
339 |
+
t = self.crop(t)
|
340 |
+
t = t["m"] if self.is_label else t["x"]
|
341 |
+
|
342 |
+
if self.half and not self.is_label:
|
343 |
+
t = t.half()
|
344 |
+
|
345 |
+
return t
|
346 |
+
|
347 |
+
def __call__(self, x):
|
348 |
+
"""
|
349 |
+
normalize, rescale, resize, crop in the center
|
350 |
+
|
351 |
+
x can be: dict {"task": data} list [data, ..] or data
|
352 |
+
data ^ can be a str, a Path, a numpy arrray or a Tensor
|
353 |
+
"""
|
354 |
+
if isinstance(x, dict):
|
355 |
+
return {k: self.process(v) for k, v in x.items()}
|
356 |
+
|
357 |
+
if isinstance(x, list):
|
358 |
+
return [self.process(t) for t in x]
|
359 |
+
|
360 |
+
return self.process(x)
|
361 |
+
|
362 |
+
|
363 |
+
class PrepareTest:
|
364 |
+
"""
|
365 |
+
Transform which:
|
366 |
+
- transforms a str or an array into a tensor
|
367 |
+
- resizes the image to keep the aspect ratio
|
368 |
+
- crops in the center of the resized image
|
369 |
+
- normalize to 0:1 (optional)
|
370 |
+
- rescale to -1:1 (optional)
|
371 |
+
"""
|
372 |
+
|
373 |
+
def __init__(self, target_size=640, half=False):
|
374 |
+
self.resize = Resize(target_size, keep_aspect_ratio=True)
|
375 |
+
self.crop = RandomCrop((target_size, target_size), center=True)
|
376 |
+
self.half = half
|
377 |
+
|
378 |
+
def process(self, t, normalize=False, rescale=False):
|
379 |
+
if isinstance(t, (str, Path)):
|
380 |
+
# t = img_as_float(imread(str(t)))
|
381 |
+
t = imread(str(t))
|
382 |
+
if t.shape[-1] == 4:
|
383 |
+
# t = rgba2rgb(t)
|
384 |
+
t = t[:, :, :3]
|
385 |
+
if np.ndim(t) == 2:
|
386 |
+
t = np.repeat(t[:, :, np.newaxis], 3, axis=2)
|
387 |
+
|
388 |
+
if isinstance(t, np.ndarray):
|
389 |
+
t = torch.from_numpy(t)
|
390 |
+
t = t.permute(2, 0, 1)
|
391 |
+
|
392 |
+
if len(t.shape) == 3:
|
393 |
+
t = t.unsqueeze(0)
|
394 |
+
|
395 |
+
t = t.to(torch.float32)
|
396 |
+
normalize(t) if normalize else t
|
397 |
+
(t - 0.5) * 2 if rescale else t
|
398 |
+
t = {"x": t}
|
399 |
+
t = self.resize(t)
|
400 |
+
t = self.crop(t)
|
401 |
+
t = t["x"]
|
402 |
+
|
403 |
+
if self.half:
|
404 |
+
return t.to(torch.float16)
|
405 |
+
|
406 |
+
return t
|
407 |
+
|
408 |
+
def __call__(self, x, normalize=False, rescale=False):
|
409 |
+
"""
|
410 |
+
Call process()
|
411 |
+
|
412 |
+
x can be: dict {"task": data} list [data, ..] or data
|
413 |
+
data ^ can be a str, a Path, a numpy arrray or a Tensor
|
414 |
+
"""
|
415 |
+
if isinstance(x, dict):
|
416 |
+
return {k: self.process(v, normalize, rescale) for k, v in x.items()}
|
417 |
+
|
418 |
+
if isinstance(x, list):
|
419 |
+
return [self.process(t, normalize, rescale) for t in x]
|
420 |
+
|
421 |
+
return self.process(x, normalize, rescale)
|
422 |
+
|
423 |
+
|
424 |
+
def get_transform(transform_item, mode):
|
425 |
+
"""Returns the torchivion transform function associated to a
|
426 |
+
transform_item listed in opts.data.transforms ; transform_item is
|
427 |
+
an addict.Dict
|
428 |
+
"""
|
429 |
+
|
430 |
+
if transform_item.name == "crop" and not (
|
431 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
432 |
+
):
|
433 |
+
return RandomCrop(
|
434 |
+
(transform_item.height, transform_item.width),
|
435 |
+
center=transform_item.center == mode,
|
436 |
+
)
|
437 |
+
|
438 |
+
elif transform_item.name == "resize" and not (
|
439 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
440 |
+
):
|
441 |
+
return Resize(
|
442 |
+
transform_item.new_size, transform_item.get("keep_aspect_ratio", False)
|
443 |
+
)
|
444 |
+
|
445 |
+
elif transform_item.name == "hflip" and not (
|
446 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
447 |
+
):
|
448 |
+
return RandomHorizontalFlip(p=transform_item.p or 0.5)
|
449 |
+
|
450 |
+
elif transform_item.name == "brightness" and not (
|
451 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
452 |
+
):
|
453 |
+
return RandBrightness()
|
454 |
+
|
455 |
+
elif transform_item.name == "saturation" and not (
|
456 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
457 |
+
):
|
458 |
+
return RandSaturation()
|
459 |
+
|
460 |
+
elif transform_item.name == "contrast" and not (
|
461 |
+
transform_item.ignore is True or transform_item.ignore == mode
|
462 |
+
):
|
463 |
+
return RandContrast()
|
464 |
+
|
465 |
+
elif transform_item.ignore is True or transform_item.ignore == mode:
|
466 |
+
return None
|
467 |
+
|
468 |
+
raise ValueError("Unknown transform_item {}".format(transform_item))
|
469 |
+
|
470 |
+
|
471 |
+
def get_transforms(opts, mode, domain):
|
472 |
+
"""Get all the transform functions listed in opts.data.transforms
|
473 |
+
using get_transform(transform_item, mode)
|
474 |
+
"""
|
475 |
+
transforms = []
|
476 |
+
color_jittering_transforms = ["brightness", "saturation", "contrast"]
|
477 |
+
|
478 |
+
for t in opts.data.transforms:
|
479 |
+
if t.name not in color_jittering_transforms:
|
480 |
+
transforms.append(get_transform(t, mode))
|
481 |
+
|
482 |
+
if "p" not in opts.tasks and mode == "train":
|
483 |
+
for t in opts.data.transforms:
|
484 |
+
if t.name in color_jittering_transforms:
|
485 |
+
transforms.append(get_transform(t, mode))
|
486 |
+
|
487 |
+
transforms += [Normalize(opts), BucketizeDepth(opts, domain)]
|
488 |
+
transforms = [t for t in transforms if t is not None]
|
489 |
+
|
490 |
+
return transforms
|
491 |
+
|
492 |
+
|
493 |
+
# ----- Adapted functions from https://github.com/mit-han-lab/data-efficient-gans -----#
|
494 |
+
def rand_brightness(tensor, is_diff_augment=False):
|
495 |
+
if is_diff_augment:
|
496 |
+
assert len(tensor.shape) == 4
|
497 |
+
type_ = tensor.dtype
|
498 |
+
device_ = tensor.device
|
499 |
+
rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
|
500 |
+
return tensor + (rand_tens - 0.5)
|
501 |
+
else:
|
502 |
+
factor = random.uniform(0.5, 1.5)
|
503 |
+
tensor = adjust_brightness(tensor, brightness_factor=factor)
|
504 |
+
# dummy pixels to fool scaling and preserve range
|
505 |
+
tensor[:, :, 0, 0] = 1.0
|
506 |
+
tensor[:, :, -1, -1] = 0.0
|
507 |
+
return tensor
|
508 |
+
|
509 |
+
|
510 |
+
def rand_saturation(tensor, is_diff_augment=False):
|
511 |
+
if is_diff_augment:
|
512 |
+
assert len(tensor.shape) == 4
|
513 |
+
type_ = tensor.dtype
|
514 |
+
device_ = tensor.device
|
515 |
+
rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
|
516 |
+
x_mean = tensor.mean(dim=1, keepdim=True)
|
517 |
+
return (tensor - x_mean) * (rand_tens * 2) + x_mean
|
518 |
+
else:
|
519 |
+
factor = random.uniform(0.5, 1.5)
|
520 |
+
tensor = adjust_saturation(tensor, saturation_factor=factor)
|
521 |
+
# dummy pixels to fool scaling and preserve range
|
522 |
+
tensor[:, :, 0, 0] = 1.0
|
523 |
+
tensor[:, :, -1, -1] = 0.0
|
524 |
+
return tensor
|
525 |
+
|
526 |
+
|
527 |
+
def rand_contrast(tensor, is_diff_augment=False):
|
528 |
+
if is_diff_augment:
|
529 |
+
assert len(tensor.shape) == 4
|
530 |
+
type_ = tensor.dtype
|
531 |
+
device_ = tensor.device
|
532 |
+
rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
|
533 |
+
x_mean = tensor.mean(dim=[1, 2, 3], keepdim=True)
|
534 |
+
return (tensor - x_mean) * (rand_tens + 0.5) + x_mean
|
535 |
+
else:
|
536 |
+
factor = random.uniform(0.5, 1.5)
|
537 |
+
tensor = adjust_contrast(tensor, contrast_factor=factor)
|
538 |
+
# dummy pixels to fool scaling and preserve range
|
539 |
+
tensor[:, :, 0, 0] = 1.0
|
540 |
+
tensor[:, :, -1, -1] = 0.0
|
541 |
+
return tensor
|
542 |
+
|
543 |
+
|
544 |
+
def rand_cutout(tensor, ratio=0.5):
|
545 |
+
assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D."
|
546 |
+
type_ = tensor.dtype
|
547 |
+
device_ = tensor.device
|
548 |
+
cutout_size = int(tensor.size(-2) * ratio + 0.5), int(tensor.size(-1) * ratio + 0.5)
|
549 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
550 |
+
torch.arange(tensor.size(0), dtype=torch.long, device=device_),
|
551 |
+
torch.arange(cutout_size[0], dtype=torch.long, device=device_),
|
552 |
+
torch.arange(cutout_size[1], dtype=torch.long, device=device_),
|
553 |
+
)
|
554 |
+
size_ = [tensor.size(0), 1, 1]
|
555 |
+
offset_x = torch.randint(
|
556 |
+
0,
|
557 |
+
tensor.size(-2) + (1 - cutout_size[0] % 2),
|
558 |
+
size=size_,
|
559 |
+
device=device_,
|
560 |
+
)
|
561 |
+
offset_y = torch.randint(
|
562 |
+
0,
|
563 |
+
tensor.size(-1) + (1 - cutout_size[1] % 2),
|
564 |
+
size=size_,
|
565 |
+
device=device_,
|
566 |
+
)
|
567 |
+
grid_x = torch.clamp(
|
568 |
+
grid_x + offset_x - cutout_size[0] // 2, min=0, max=tensor.size(-2) - 1
|
569 |
+
)
|
570 |
+
grid_y = torch.clamp(
|
571 |
+
grid_y + offset_y - cutout_size[1] // 2, min=0, max=tensor.size(-1) - 1
|
572 |
+
)
|
573 |
+
mask = torch.ones(
|
574 |
+
tensor.size(0), tensor.size(2), tensor.size(3), dtype=type_, device=device_
|
575 |
+
)
|
576 |
+
mask[grid_batch, grid_x, grid_y] = 0
|
577 |
+
return tensor * mask.unsqueeze(1)
|
578 |
+
|
579 |
+
|
580 |
+
def rand_translation(tensor, ratio=0.125):
|
581 |
+
assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D."
|
582 |
+
device_ = tensor.device
|
583 |
+
shift_x, shift_y = (
|
584 |
+
int(tensor.size(2) * ratio + 0.5),
|
585 |
+
int(tensor.size(3) * ratio + 0.5),
|
586 |
+
)
|
587 |
+
translation_x = torch.randint(
|
588 |
+
-shift_x, shift_x + 1, size=[tensor.size(0), 1, 1], device=device_
|
589 |
+
)
|
590 |
+
translation_y = torch.randint(
|
591 |
+
-shift_y, shift_y + 1, size=[tensor.size(0), 1, 1], device=device_
|
592 |
+
)
|
593 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
594 |
+
torch.arange(tensor.size(0), dtype=torch.long, device=device_),
|
595 |
+
torch.arange(tensor.size(2), dtype=torch.long, device=device_),
|
596 |
+
torch.arange(tensor.size(3), dtype=torch.long, device=device_),
|
597 |
+
)
|
598 |
+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, tensor.size(2) + 1)
|
599 |
+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, tensor.size(3) + 1)
|
600 |
+
x_pad = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0])
|
601 |
+
tensor = (
|
602 |
+
x_pad.permute(0, 2, 3, 1)
|
603 |
+
.contiguous()[grid_batch, grid_x, grid_y]
|
604 |
+
.permute(0, 3, 1, 2)
|
605 |
+
)
|
606 |
+
return tensor
|
607 |
+
|
608 |
+
|
609 |
+
class DiffTransforms:
|
610 |
+
def __init__(self, diff_aug_opts):
|
611 |
+
self.do_color_jittering = diff_aug_opts.do_color_jittering
|
612 |
+
self.do_cutout = diff_aug_opts.do_cutout
|
613 |
+
self.do_translation = diff_aug_opts.do_translation
|
614 |
+
self.cutout_ratio = diff_aug_opts.cutout_ratio
|
615 |
+
self.translation_ratio = diff_aug_opts.translation_ratio
|
616 |
+
|
617 |
+
def __call__(self, tensor):
|
618 |
+
if self.do_color_jittering:
|
619 |
+
tensor = rand_brightness(tensor, is_diff_augment=True)
|
620 |
+
tensor = rand_contrast(tensor, is_diff_augment=True)
|
621 |
+
tensor = rand_saturation(tensor, is_diff_augment=True)
|
622 |
+
if self.do_translation:
|
623 |
+
tensor = rand_translation(tensor, ratio=self.translation_ratio)
|
624 |
+
if self.do_cutout:
|
625 |
+
tensor = rand_cutout(tensor, ratio=self.cutout_ratio)
|
626 |
+
return tensor
|
climategan/tutils.py
ADDED
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Tensor-utils
|
2 |
+
"""
|
3 |
+
import io
|
4 |
+
import math
|
5 |
+
from contextlib import redirect_stdout
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# from copy import copy
|
9 |
+
from threading import Thread
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from skimage import io as skio
|
15 |
+
from torch import autograd
|
16 |
+
from torch.autograd import Variable
|
17 |
+
from torch.nn import init
|
18 |
+
|
19 |
+
from climategan.utils import all_texts_to_array
|
20 |
+
|
21 |
+
|
22 |
+
def transforms_string(ts):
|
23 |
+
return " -> ".join([t.__class__.__name__ for t in ts.transforms])
|
24 |
+
|
25 |
+
|
26 |
+
def init_weights(net, init_type="normal", init_gain=0.02, verbose=0, caller=""):
|
27 |
+
"""Initialize network weights.
|
28 |
+
Parameters:
|
29 |
+
net (network) -- network to be initialized
|
30 |
+
init_type (str) -- the name of an initialization method:
|
31 |
+
normal | xavier | kaiming | orthogonal
|
32 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
33 |
+
|
34 |
+
We use 'normal' in the original pix2pix and CycleGAN paper.
|
35 |
+
But xavier and kaiming might work better for some applications.
|
36 |
+
Feel free to try yourself.
|
37 |
+
"""
|
38 |
+
|
39 |
+
if not init_type:
|
40 |
+
print(
|
41 |
+
"init_weights({}): init_type is {}, defaulting to normal".format(
|
42 |
+
caller + " " + net.__class__.__name__, init_type
|
43 |
+
)
|
44 |
+
)
|
45 |
+
init_type = "normal"
|
46 |
+
if not init_gain:
|
47 |
+
print(
|
48 |
+
"init_weights({}): init_gain is {}, defaulting to normal".format(
|
49 |
+
caller + " " + net.__class__.__name__, init_type
|
50 |
+
)
|
51 |
+
)
|
52 |
+
init_gain = 0.02
|
53 |
+
|
54 |
+
def init_func(m):
|
55 |
+
classname = m.__class__.__name__
|
56 |
+
if classname.find("BatchNorm2d") != -1:
|
57 |
+
if hasattr(m, "weight") and m.weight is not None:
|
58 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
59 |
+
if hasattr(m, "bias") and m.bias is not None:
|
60 |
+
init.constant_(m.bias.data, 0.0)
|
61 |
+
elif hasattr(m, "weight") and (
|
62 |
+
classname.find("Conv") != -1 or classname.find("Linear") != -1
|
63 |
+
):
|
64 |
+
if init_type == "normal":
|
65 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
66 |
+
elif init_type == "xavier":
|
67 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
68 |
+
elif init_type == "xavier_uniform":
|
69 |
+
init.xavier_uniform_(m.weight.data, gain=1.0)
|
70 |
+
elif init_type == "kaiming":
|
71 |
+
init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
|
72 |
+
elif init_type == "orthogonal":
|
73 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
74 |
+
elif init_type == "none": # uses pytorch's default init method
|
75 |
+
m.reset_parameters()
|
76 |
+
else:
|
77 |
+
raise NotImplementedError(
|
78 |
+
"initialization method [%s] is not implemented" % init_type
|
79 |
+
)
|
80 |
+
if hasattr(m, "bias") and m.bias is not None:
|
81 |
+
init.constant_(m.bias.data, 0.0)
|
82 |
+
|
83 |
+
if verbose > 0:
|
84 |
+
print("initialize %s with %s" % (net.__class__.__name__, init_type))
|
85 |
+
net.apply(init_func)
|
86 |
+
|
87 |
+
|
88 |
+
def domains_to_class_tensor(domains, one_hot=False):
|
89 |
+
"""Converts a list of strings to a 1D Tensor representing the domains
|
90 |
+
|
91 |
+
domains_to_class_tensor(["sf", "rn"])
|
92 |
+
>>> torch.Tensor([2, 1])
|
93 |
+
|
94 |
+
Args:
|
95 |
+
domain (list(str)): each element of the list should be in {rf, rn, sf, sn}
|
96 |
+
one_hot (bool, optional): whether or not to 1-h encode class labels.
|
97 |
+
Defaults to False.
|
98 |
+
Raises:
|
99 |
+
ValueError: One of the domains listed is not in {rf, rn, sf, sn}
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
torch.Tensor: 1D tensor mapping a domain to an int (not 1-hot) or 1-hot
|
103 |
+
domain labels in a 2D tensor
|
104 |
+
"""
|
105 |
+
|
106 |
+
mapping = {"r": 0, "s": 1}
|
107 |
+
|
108 |
+
if not all(domain in mapping for domain in domains):
|
109 |
+
raise ValueError(
|
110 |
+
"Unknown domains {} should be in {}".format(domains, list(mapping.keys()))
|
111 |
+
)
|
112 |
+
|
113 |
+
target = torch.tensor([mapping[domain] for domain in domains])
|
114 |
+
|
115 |
+
if one_hot:
|
116 |
+
one_hot_target = torch.FloatTensor(len(target), 2) # 2 domains
|
117 |
+
one_hot_target.zero_()
|
118 |
+
one_hot_target.scatter_(1, target.unsqueeze(1), 1)
|
119 |
+
# https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
|
120 |
+
target = one_hot_target
|
121 |
+
return target
|
122 |
+
|
123 |
+
|
124 |
+
def fake_domains_to_class_tensor(domains, one_hot=False):
|
125 |
+
"""Converts a list of strings to a 1D Tensor representing the fake domains
|
126 |
+
(real or sim only)
|
127 |
+
|
128 |
+
fake_domains_to_class_tensor(["s", "r"], False)
|
129 |
+
>>> torch.Tensor([0, 2])
|
130 |
+
|
131 |
+
|
132 |
+
Args:
|
133 |
+
domain (list(str)): each element of the list should be in {r, s}
|
134 |
+
one_hot (bool, optional): whether or not to 1-h encode class labels.
|
135 |
+
Defaults to False.
|
136 |
+
Raises:
|
137 |
+
ValueError: One of the domains listed is not in {rf, rn, sf, sn}
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: 1D tensor mapping a domain to an int (not 1-hot) or
|
141 |
+
a 2D tensor filled with 0.25 to fool the classifier (equiprobability
|
142 |
+
for each domain).
|
143 |
+
"""
|
144 |
+
if one_hot:
|
145 |
+
target = torch.FloatTensor(len(domains), 2)
|
146 |
+
target.fill_(0.5)
|
147 |
+
|
148 |
+
else:
|
149 |
+
mapping = {"r": 1, "s": 0}
|
150 |
+
|
151 |
+
if not all(domain in mapping for domain in domains):
|
152 |
+
raise ValueError(
|
153 |
+
"Unknown domains {} should be in {}".format(
|
154 |
+
domains, list(mapping.keys())
|
155 |
+
)
|
156 |
+
)
|
157 |
+
|
158 |
+
target = torch.tensor([mapping[domain] for domain in domains])
|
159 |
+
return target
|
160 |
+
|
161 |
+
|
162 |
+
def show_tanh_tensor(tensor):
|
163 |
+
import skimage
|
164 |
+
|
165 |
+
if isinstance(tensor, torch.Tensor):
|
166 |
+
image = tensor.permute(1, 2, 0).detach().numpy()
|
167 |
+
else:
|
168 |
+
image = tensor
|
169 |
+
if image.shape[-1] != 3:
|
170 |
+
image = image.transpose(1, 2, 0)
|
171 |
+
|
172 |
+
if image.min() < 0 and image.min() > -1:
|
173 |
+
image = image / 2 + 0.5
|
174 |
+
elif image.min() < -1:
|
175 |
+
raise ValueError("can't handle this data")
|
176 |
+
|
177 |
+
skimage.io.imshow(image)
|
178 |
+
|
179 |
+
|
180 |
+
def normalize_tensor(t):
|
181 |
+
"""
|
182 |
+
Brings any tensor to the [0; 1] range.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
t (torch.Tensor): input to normalize
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
torch.Tensor: t projected to [0; 1]
|
189 |
+
"""
|
190 |
+
t = t - torch.min(t)
|
191 |
+
t = t / torch.max(t)
|
192 |
+
return t
|
193 |
+
|
194 |
+
|
195 |
+
def get_normalized_depth_t(tensor, domain, normalize=False, log=True):
|
196 |
+
assert not (normalize and log)
|
197 |
+
if domain == "r":
|
198 |
+
# megadepth depth
|
199 |
+
tensor = tensor.unsqueeze(0)
|
200 |
+
tensor = tensor - torch.min(tensor)
|
201 |
+
tensor = torch.true_divide(tensor, torch.max(tensor))
|
202 |
+
|
203 |
+
elif domain == "s":
|
204 |
+
# from 3-channel depth encoding from Unity simulator to 1-channel [0-1] values
|
205 |
+
tensor = decode_unity_depth_t(tensor, log=log, normalize=normalize)
|
206 |
+
|
207 |
+
elif domain == "kitti":
|
208 |
+
tensor = tensor / 100
|
209 |
+
if not log:
|
210 |
+
tensor = 1 / tensor
|
211 |
+
if normalize:
|
212 |
+
tensor = tensor - tensor.min()
|
213 |
+
tensor = tensor / tensor.max()
|
214 |
+
else:
|
215 |
+
tensor = torch.log(tensor)
|
216 |
+
|
217 |
+
tensor = tensor.unsqueeze(0)
|
218 |
+
|
219 |
+
return tensor
|
220 |
+
|
221 |
+
|
222 |
+
def decode_bucketed_depth(tensor, opts):
|
223 |
+
# tensor is size 1 x C x H x W
|
224 |
+
assert tensor.shape[0] == 1
|
225 |
+
idx = torch.argmax(tensor.squeeze(0), dim=0) # channels become dim 0 with squeeze
|
226 |
+
linspace_args = (
|
227 |
+
opts.gen.d.classify.linspace.min,
|
228 |
+
opts.gen.d.classify.linspace.max,
|
229 |
+
opts.gen.d.classify.linspace.buckets,
|
230 |
+
)
|
231 |
+
indexer = torch.linspace(*linspace_args)
|
232 |
+
log_depth = indexer[idx.long()].to(torch.float32) # H x W
|
233 |
+
depth = torch.exp(log_depth)
|
234 |
+
return depth.unsqueeze(0).unsqueeze(0).to(tensor.device)
|
235 |
+
|
236 |
+
|
237 |
+
def decode_unity_depth_t(unity_depth, log=True, normalize=False, numpy=False, far=1000):
|
238 |
+
"""Transforms the 3-channel encoded depth map from our Unity simulator
|
239 |
+
to 1-channel depth map containing metric depth values.
|
240 |
+
The depth is encoded in the following way:
|
241 |
+
- The information from the simulator is (1 - LinearDepth (in [0,1])).
|
242 |
+
far corresponds to the furthest distance to the camera included in the
|
243 |
+
depth map.
|
244 |
+
LinearDepth * far gives the real metric distance to the camera.
|
245 |
+
- depth is first divided in 31 slices encoded in R channel with values ranging
|
246 |
+
from 0 to 247
|
247 |
+
- each slice is divided again in 31 slices, whose value is encoded in G channel
|
248 |
+
- each of the G slices is divided into 256 slices, encoded in B channel
|
249 |
+
|
250 |
+
In total, we have a discretization of depth into N = 31*31*256 - 1 possible values,
|
251 |
+
covering a range of far/N meters.
|
252 |
+
|
253 |
+
Note that, what we encode here is 1 - LinearDepth so that the furthest point is
|
254 |
+
[0,0,0] (that is sky) and the closest point[255,255,255]
|
255 |
+
|
256 |
+
The metric distance associated to a pixel whose depth is (R,G,B) is :
|
257 |
+
d = (far/N) * [((255 - R)//8)*256*31 + ((255 - G)//8)*256 + (255 - B)]
|
258 |
+
|
259 |
+
* torch.Tensor in [0, 1] as torch.float32 if numpy == False
|
260 |
+
|
261 |
+
* else numpy.array in [0, 255] as np.uint8
|
262 |
+
|
263 |
+
Args:
|
264 |
+
unity_depth (torch.Tensor): one depth map obtained from our simulator
|
265 |
+
numpy (bool, optional): Whether to return a float tensor or an int array.
|
266 |
+
Defaults to False.
|
267 |
+
far: far parameter of the camera in Unity simulator.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
[torch.Tensor or numpy.array]: decoded depth
|
271 |
+
"""
|
272 |
+
R = unity_depth[:, :, 0]
|
273 |
+
G = unity_depth[:, :, 1]
|
274 |
+
B = unity_depth[:, :, 2]
|
275 |
+
|
276 |
+
R = ((247 - R) / 8).type(torch.IntTensor)
|
277 |
+
G = ((247 - G) / 8).type(torch.IntTensor)
|
278 |
+
B = (255 - B).type(torch.IntTensor)
|
279 |
+
depth = ((R * 256 * 31 + G * 256 + B).type(torch.FloatTensor)) / (256 * 31 * 31 - 1)
|
280 |
+
depth = depth * far
|
281 |
+
if not log:
|
282 |
+
depth = 1 / depth
|
283 |
+
depth = depth.unsqueeze(0) # (depth * far).unsqueeze(0)
|
284 |
+
|
285 |
+
if log:
|
286 |
+
depth = torch.log(depth)
|
287 |
+
if normalize:
|
288 |
+
depth = depth - torch.min(depth)
|
289 |
+
depth /= torch.max(depth)
|
290 |
+
if numpy:
|
291 |
+
depth = depth.data.cpu().numpy()
|
292 |
+
return depth.astype(np.uint8).squeeze()
|
293 |
+
return depth
|
294 |
+
|
295 |
+
|
296 |
+
def to_inv_depth(log_depth, numpy=False):
|
297 |
+
"""Convert log depth tensor to inverse depth image for display
|
298 |
+
|
299 |
+
Args:
|
300 |
+
depth (Tensor): log depth float tensor
|
301 |
+
"""
|
302 |
+
depth = torch.exp(log_depth)
|
303 |
+
# visualize prediction using inverse depth, so that we don't need sky
|
304 |
+
# segmentation (if you want to use RGB map for visualization,
|
305 |
+
# you have to run semantic segmentation to mask the sky first
|
306 |
+
# since the depth of sky is random from CNN)
|
307 |
+
inv_depth = 1 / depth
|
308 |
+
inv_depth /= torch.max(inv_depth)
|
309 |
+
if numpy:
|
310 |
+
inv_depth = inv_depth.data.cpu().numpy()
|
311 |
+
# you might also use percentile for better visualization
|
312 |
+
|
313 |
+
return inv_depth
|
314 |
+
|
315 |
+
|
316 |
+
def shuffle_batch_tuple(mbt):
|
317 |
+
"""shuffle the order of domains in the batch
|
318 |
+
|
319 |
+
Args:
|
320 |
+
mbt (tuple): multi-batch tuple
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
list: randomized list of domain-specific batches
|
324 |
+
"""
|
325 |
+
assert isinstance(mbt, (tuple, list))
|
326 |
+
assert len(mbt) > 0
|
327 |
+
perm = np.random.permutation(len(mbt))
|
328 |
+
return [mbt[i] for i in perm]
|
329 |
+
|
330 |
+
|
331 |
+
def slice_batch(batch, slice_size):
|
332 |
+
assert slice_size > 0
|
333 |
+
for k, v in batch.items():
|
334 |
+
if isinstance(v, dict):
|
335 |
+
for task, d in v.items():
|
336 |
+
batch[k][task] = d[:slice_size]
|
337 |
+
else:
|
338 |
+
batch[k] = v[:slice_size]
|
339 |
+
return batch
|
340 |
+
|
341 |
+
|
342 |
+
def save_tanh_tensor(image, path):
|
343 |
+
"""Save an image which can be numpy or tensor, 2 or 3 dims (no batch)
|
344 |
+
to path.
|
345 |
+
|
346 |
+
Args:
|
347 |
+
image (np.array or torch.Tensor): image to save
|
348 |
+
path (pathlib.Path or str): where to save the image
|
349 |
+
"""
|
350 |
+
path = Path(path)
|
351 |
+
if isinstance(image, torch.Tensor):
|
352 |
+
image = image.detach().cpu().numpy()
|
353 |
+
if image.shape[-1] != 3 and image.shape[0] == 3:
|
354 |
+
image = np.transpose(image, (1, 2, 0))
|
355 |
+
if image.min() < 0 and image.min() > -1:
|
356 |
+
image = image / 2 + 0.5
|
357 |
+
elif image.min() < -1:
|
358 |
+
image -= image.min()
|
359 |
+
image /= image.max()
|
360 |
+
# print("Warning: scaling image data in save_tanh_tensor")
|
361 |
+
|
362 |
+
skio.imsave(path, (image * 255).astype(np.uint8))
|
363 |
+
|
364 |
+
|
365 |
+
def save_batch(multi_domain_batch, root="./", step=0, num_threads=5):
|
366 |
+
root = Path(root)
|
367 |
+
root.mkdir(parents=True, exist_ok=True)
|
368 |
+
images_to_save = {"paths": [], "images": []}
|
369 |
+
for domain, batch in multi_domain_batch.items():
|
370 |
+
y = batch["data"].get("y")
|
371 |
+
x = batch["data"]["x"]
|
372 |
+
if y is not None:
|
373 |
+
paths = batch["paths"]["x"]
|
374 |
+
imtensor = torch.cat([x, y], dim=-1)
|
375 |
+
for i, im in enumerate(imtensor):
|
376 |
+
imid = Path(paths[i]).stem[:10]
|
377 |
+
images_to_save["paths"] += [
|
378 |
+
root / "im_{}_{}_{}.png".format(step, domain, imid)
|
379 |
+
]
|
380 |
+
images_to_save["images"].append(im)
|
381 |
+
if num_threads > 0:
|
382 |
+
threaded_write(images_to_save["images"], images_to_save["paths"], num_threads)
|
383 |
+
else:
|
384 |
+
for im, path in zip(images_to_save["images"], images_to_save["paths"]):
|
385 |
+
save_tanh_tensor(im, path)
|
386 |
+
|
387 |
+
|
388 |
+
def threaded_write(images, paths, num_threads=5):
|
389 |
+
t_im = []
|
390 |
+
t_p = []
|
391 |
+
for im, p in zip(images, paths):
|
392 |
+
t_im.append(im)
|
393 |
+
t_p.append(p)
|
394 |
+
if len(t_im) == num_threads:
|
395 |
+
ts = [
|
396 |
+
Thread(target=save_tanh_tensor, args=(_i, _p))
|
397 |
+
for _i, _p in zip(t_im, t_p)
|
398 |
+
]
|
399 |
+
list(map(lambda t: t.start(), ts))
|
400 |
+
list(map(lambda t: t.join(), ts))
|
401 |
+
t_im = []
|
402 |
+
t_p = []
|
403 |
+
if t_im:
|
404 |
+
ts = [
|
405 |
+
Thread(target=save_tanh_tensor, args=(_i, _p)) for _i, _p in zip(t_im, t_p)
|
406 |
+
]
|
407 |
+
list(map(lambda t: t.start(), ts))
|
408 |
+
list(map(lambda t: t.join(), ts))
|
409 |
+
|
410 |
+
|
411 |
+
def get_num_params(model):
|
412 |
+
total_params = sum(p.numel() for p in model.parameters())
|
413 |
+
return total_params
|
414 |
+
|
415 |
+
|
416 |
+
def vgg_preprocess(batch):
|
417 |
+
"""Preprocess batch to use VGG model"""
|
418 |
+
tensortype = type(batch.data)
|
419 |
+
(r, g, b) = torch.chunk(batch, 3, dim=1)
|
420 |
+
batch = torch.cat((b, g, r), dim=1) # convert RGB to BGR
|
421 |
+
batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255]
|
422 |
+
mean = tensortype(batch.data.size()).cuda()
|
423 |
+
mean[:, 0, :, :] = 103.939
|
424 |
+
mean[:, 1, :, :] = 116.779
|
425 |
+
mean[:, 2, :, :] = 123.680
|
426 |
+
batch = batch.sub(Variable(mean)) # subtract mean
|
427 |
+
return batch
|
428 |
+
|
429 |
+
|
430 |
+
def zero_grad(model: nn.Module):
|
431 |
+
"""
|
432 |
+
Sets gradients to None. Mode efficient than model.zero_grad()
|
433 |
+
or opt.zero_grad() according to https://www.youtube.com/watch?v=9mS1fIYj1So
|
434 |
+
|
435 |
+
Args:
|
436 |
+
model (nn.Module): model to zero out
|
437 |
+
"""
|
438 |
+
for p in model.parameters():
|
439 |
+
p.grad = None
|
440 |
+
|
441 |
+
|
442 |
+
# Take the prediction of fake and real images from the combined batch
|
443 |
+
def divide_pred(disc_output):
|
444 |
+
"""
|
445 |
+
Divide a multiscale discriminator's output into 2 sets of tensors,
|
446 |
+
expecting the input to the discriminator to be a concatenation
|
447 |
+
on the batch axis of real and fake (or fake and real) images,
|
448 |
+
effectively doubling the batch size for better batchnorm statistics
|
449 |
+
|
450 |
+
Args:
|
451 |
+
disc_output (list | torch.Tensor): Discriminator output to split
|
452 |
+
|
453 |
+
Returns:
|
454 |
+
list | torch.Tensor[type]: pair of split outputs
|
455 |
+
"""
|
456 |
+
# https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py
|
457 |
+
# the prediction contains the intermediate outputs of multiscale GAN,
|
458 |
+
# so it's usually a list
|
459 |
+
if type(disc_output) == list:
|
460 |
+
half1 = []
|
461 |
+
half2 = []
|
462 |
+
for p in disc_output:
|
463 |
+
half1.append([tensor[: tensor.size(0) // 2] for tensor in p])
|
464 |
+
half2.append([tensor[tensor.size(0) // 2 :] for tensor in p])
|
465 |
+
else:
|
466 |
+
half1 = disc_output[: disc_output.size(0) // 2]
|
467 |
+
half2 = disc_output[disc_output.size(0) // 2 :]
|
468 |
+
|
469 |
+
return half1, half2
|
470 |
+
|
471 |
+
|
472 |
+
def is_tpu_available():
|
473 |
+
_torch_tpu_available = False
|
474 |
+
try:
|
475 |
+
import torch_xla.core.xla_model as xm # type: ignore
|
476 |
+
|
477 |
+
if "xla" in str(xm.xla_device()):
|
478 |
+
_torch_tpu_available = True
|
479 |
+
else:
|
480 |
+
_torch_tpu_available = False
|
481 |
+
except ImportError:
|
482 |
+
_torch_tpu_available = False
|
483 |
+
|
484 |
+
return _torch_tpu_available
|
485 |
+
|
486 |
+
|
487 |
+
def get_WGAN_gradient(input, output):
|
488 |
+
# github code reference:
|
489 |
+
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
|
490 |
+
# Calculate the gradient that WGAN-gp needs
|
491 |
+
grads = autograd.grad(
|
492 |
+
outputs=output,
|
493 |
+
inputs=input,
|
494 |
+
grad_outputs=torch.ones(output.size()).cuda(),
|
495 |
+
create_graph=True,
|
496 |
+
retain_graph=True,
|
497 |
+
only_inputs=True,
|
498 |
+
)[0]
|
499 |
+
grads = grads.view(grads.size(0), -1)
|
500 |
+
gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
|
501 |
+
return gp
|
502 |
+
|
503 |
+
|
504 |
+
def print_num_parameters(trainer, force=False):
|
505 |
+
if trainer.verbose == 0 and not force:
|
506 |
+
return
|
507 |
+
print("-" * 35)
|
508 |
+
if trainer.G.encoder is not None:
|
509 |
+
print(
|
510 |
+
"{:21}:".format("num params encoder"),
|
511 |
+
f"{get_num_params(trainer.G.encoder):12,}",
|
512 |
+
)
|
513 |
+
for d in trainer.G.decoders.keys():
|
514 |
+
print(
|
515 |
+
"{:21}:".format(f"num params decoder {d}"),
|
516 |
+
f"{get_num_params(trainer.G.decoders[d]):12,}",
|
517 |
+
)
|
518 |
+
|
519 |
+
print(
|
520 |
+
"{:21}:".format("num params painter"),
|
521 |
+
f"{get_num_params(trainer.G.painter):12,}",
|
522 |
+
)
|
523 |
+
|
524 |
+
if trainer.D is not None:
|
525 |
+
for d in trainer.D.keys():
|
526 |
+
print(
|
527 |
+
"{:21}:".format(f"num params discrim {d}"),
|
528 |
+
f"{get_num_params(trainer.D[d]):12,}",
|
529 |
+
)
|
530 |
+
|
531 |
+
print("-" * 35)
|
532 |
+
|
533 |
+
|
534 |
+
def srgb2lrgb(x):
|
535 |
+
x = normalize(x)
|
536 |
+
im = ((x + 0.055) / 1.055) ** (2.4)
|
537 |
+
im[x <= 0.04045] = x[x <= 0.04045] / 12.92
|
538 |
+
return im
|
539 |
+
|
540 |
+
|
541 |
+
def lrgb2srgb(ims):
|
542 |
+
if len(ims.shape) == 3:
|
543 |
+
ims = [ims]
|
544 |
+
stack = False
|
545 |
+
else:
|
546 |
+
ims = list(ims)
|
547 |
+
stack = True
|
548 |
+
|
549 |
+
outs = []
|
550 |
+
for im in ims:
|
551 |
+
|
552 |
+
out = torch.zeros_like(im)
|
553 |
+
for k in range(3):
|
554 |
+
temp = im[k, :, :]
|
555 |
+
|
556 |
+
out[k, :, :] = 12.92 * temp * (temp <= 0.0031308) + (
|
557 |
+
1.055 * torch.pow(temp, (1 / 2.4)) - 0.055
|
558 |
+
) * (temp > 0.0031308)
|
559 |
+
outs.append(out)
|
560 |
+
|
561 |
+
if stack:
|
562 |
+
return torch.stack(outs)
|
563 |
+
|
564 |
+
return outs[0]
|
565 |
+
|
566 |
+
|
567 |
+
def normalize(t, mini=0, maxi=1):
|
568 |
+
if len(t.shape) == 3:
|
569 |
+
return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
|
570 |
+
|
571 |
+
batch_size = t.shape[0]
|
572 |
+
min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, 1, 1, 1)
|
573 |
+
t = t - min_t
|
574 |
+
max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, 1, 1, 1)
|
575 |
+
t = t / max_t
|
576 |
+
return mini + (maxi - mini) * t
|
577 |
+
|
578 |
+
|
579 |
+
def retrieve_sky_mask(seg):
|
580 |
+
"""
|
581 |
+
get the binary mask for the sky given a segmentation tensor
|
582 |
+
of logits (N x C x H x W) or labels (N x H x W)
|
583 |
+
|
584 |
+
Args:
|
585 |
+
seg (torch.Tensor): Segmentation map
|
586 |
+
|
587 |
+
Returns:
|
588 |
+
torch.Tensor: Sky mask
|
589 |
+
"""
|
590 |
+
if len(seg.shape) == 4: # Predictions
|
591 |
+
seg_ind = torch.argmax(seg, dim=1)
|
592 |
+
else:
|
593 |
+
seg_ind = seg
|
594 |
+
|
595 |
+
sky_mask = seg_ind == 9
|
596 |
+
return sky_mask
|
597 |
+
|
598 |
+
|
599 |
+
def all_texts_to_tensors(texts, width=640, height=40):
|
600 |
+
"""
|
601 |
+
Creates a list of tensors with texts from PIL images
|
602 |
+
|
603 |
+
Args:
|
604 |
+
texts (list(str)): texts to write
|
605 |
+
width (int, optional): width of individual texts. Defaults to 640.
|
606 |
+
height (int, optional): height of individual texts. Defaults to 40.
|
607 |
+
|
608 |
+
Returns:
|
609 |
+
list(torch.Tensor): len(texts) tensors 3 x height x width
|
610 |
+
"""
|
611 |
+
arrays = all_texts_to_array(texts, width, height)
|
612 |
+
arrays = [array.transpose(2, 0, 1) for array in arrays]
|
613 |
+
return [torch.tensor(array) for array in arrays]
|
614 |
+
|
615 |
+
|
616 |
+
def write_architecture(trainer):
|
617 |
+
stem = "archi"
|
618 |
+
out = Path(trainer.opts.output_path)
|
619 |
+
|
620 |
+
# encoder
|
621 |
+
with open(out / f"{stem}_encoder.txt", "w") as f:
|
622 |
+
f.write(str(trainer.G.encoder))
|
623 |
+
|
624 |
+
# decoders
|
625 |
+
for k, v in trainer.G.decoders.items():
|
626 |
+
with open(out / f"{stem}_decoder_{k}.txt", "w") as f:
|
627 |
+
f.write(str(v))
|
628 |
+
|
629 |
+
# painter
|
630 |
+
if get_num_params(trainer.G.painter) > 0:
|
631 |
+
with open(out / f"{stem}_painter.txt", "w") as f:
|
632 |
+
f.write(str(trainer.G.painter))
|
633 |
+
|
634 |
+
# discriminators
|
635 |
+
if get_num_params(trainer.D) > 0:
|
636 |
+
for k, v in trainer.D.items():
|
637 |
+
with open(out / f"{stem}_discriminator_{k}.txt", "w") as f:
|
638 |
+
f.write(str(v))
|
639 |
+
|
640 |
+
with io.StringIO() as buf, redirect_stdout(buf):
|
641 |
+
print_num_parameters(trainer)
|
642 |
+
output = buf.getvalue()
|
643 |
+
with open(out / "archi_num_params.txt", "w") as f:
|
644 |
+
f.write(output)
|
645 |
+
|
646 |
+
|
647 |
+
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
648 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
649 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
650 |
+
|
651 |
+
grid = (
|
652 |
+
torch.stack(
|
653 |
+
torch.meshgrid(
|
654 |
+
torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
|
655 |
+
),
|
656 |
+
dim=-1,
|
657 |
+
)
|
658 |
+
% 1
|
659 |
+
)
|
660 |
+
angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
|
661 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
662 |
+
|
663 |
+
tile_grads = (
|
664 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
665 |
+
.repeat_interleave(d[0], 0)
|
666 |
+
.repeat_interleave(d[1], 1)
|
667 |
+
)
|
668 |
+
dot = lambda grad, shift: ( # noqa: E731
|
669 |
+
torch.stack(
|
670 |
+
(
|
671 |
+
grid[: shape[0], : shape[1], 0] + shift[0],
|
672 |
+
grid[: shape[0], : shape[1], 1] + shift[1],
|
673 |
+
),
|
674 |
+
dim=-1,
|
675 |
+
)
|
676 |
+
* grad[: shape[0], : shape[1]]
|
677 |
+
).sum(dim=-1)
|
678 |
+
|
679 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
680 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
681 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
682 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
683 |
+
t = fade(grid[: shape[0], : shape[1]])
|
684 |
+
return math.sqrt(2) * torch.lerp(
|
685 |
+
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
|
686 |
+
)
|
687 |
+
|
688 |
+
|
689 |
+
def mix_noise(x, mask, res=(8, 3), weight=0.1):
|
690 |
+
noise = rand_perlin_2d(x.shape[-2:], res).unsqueeze(0).unsqueeze(0).to(x.device)
|
691 |
+
noise = noise - noise.min()
|
692 |
+
mask = mask.repeat(1, 3, 1, 1).to(x.device).to(torch.float16)
|
693 |
+
y = mask * (weight * noise + (1 - weight) * x) + (1 - mask) * x
|
694 |
+
return y
|
695 |
+
|
696 |
+
|
697 |
+
def tensor_ims_to_np_uint8s(ims):
|
698 |
+
"""
|
699 |
+
transform a CHW of NCHW tensor into a list of np.uint8 [0, 255]
|
700 |
+
image arrays
|
701 |
+
|
702 |
+
Args:
|
703 |
+
ims (torch.Tensor | list): [description]
|
704 |
+
"""
|
705 |
+
if not isinstance(ims, list):
|
706 |
+
assert isinstance(ims, torch.Tensor)
|
707 |
+
if ims.ndim == 3:
|
708 |
+
ims = [ims]
|
709 |
+
|
710 |
+
nps = []
|
711 |
+
for t in ims:
|
712 |
+
if t.shape[0] == 3:
|
713 |
+
t = t.permute(1, 2, 0)
|
714 |
+
else:
|
715 |
+
assert t.shape[-1] == 3
|
716 |
+
|
717 |
+
n = t.cpu().numpy()
|
718 |
+
n = (n + 1) / 2 * 255
|
719 |
+
nps.append(n.astype(np.uint8))
|
720 |
+
|
721 |
+
return nps[0] if len(nps) == 1 else nps
|
climategan/utils.py
ADDED
@@ -0,0 +1,1063 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""All non-tensor utils
|
2 |
+
"""
|
3 |
+
import contextlib
|
4 |
+
import datetime
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import re
|
8 |
+
import shutil
|
9 |
+
import subprocess
|
10 |
+
import time
|
11 |
+
import traceback
|
12 |
+
from os.path import expandvars
|
13 |
+
from pathlib import Path
|
14 |
+
from typing import Any, List, Optional, Union
|
15 |
+
from uuid import uuid4
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import yaml
|
20 |
+
from addict import Dict
|
21 |
+
from comet_ml import Experiment
|
22 |
+
|
23 |
+
comet_kwargs = {
|
24 |
+
"auto_metric_logging": False,
|
25 |
+
"parse_args": True,
|
26 |
+
"log_env_gpu": True,
|
27 |
+
"log_env_cpu": True,
|
28 |
+
"display_summary_level": 0,
|
29 |
+
}
|
30 |
+
|
31 |
+
IMG_EXTENSIONS = set(
|
32 |
+
[".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"]
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
def resolve(path):
|
37 |
+
"""
|
38 |
+
fully resolve a path:
|
39 |
+
resolve env vars ($HOME etc.) -> expand user (~) -> make absolute
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
pathlib.Path: resolved absolute path
|
43 |
+
"""
|
44 |
+
return Path(expandvars(str(path))).expanduser().resolve()
|
45 |
+
|
46 |
+
|
47 |
+
def copy_run_files(opts: Dict) -> None:
|
48 |
+
"""
|
49 |
+
Copy the opts's sbatch_file to output_path
|
50 |
+
|
51 |
+
Args:
|
52 |
+
opts (addict.Dict): options
|
53 |
+
"""
|
54 |
+
if opts.sbatch_file:
|
55 |
+
p = resolve(opts.sbatch_file)
|
56 |
+
if p.exists():
|
57 |
+
o = resolve(opts.output_path)
|
58 |
+
if o.exists():
|
59 |
+
shutil.copyfile(p, o / p.name)
|
60 |
+
if opts.exp_file:
|
61 |
+
p = resolve(opts.exp_file)
|
62 |
+
if p.exists():
|
63 |
+
o = resolve(opts.output_path)
|
64 |
+
if o.exists():
|
65 |
+
shutil.copyfile(p, o / p.name)
|
66 |
+
|
67 |
+
|
68 |
+
def merge(
|
69 |
+
source: Union[dict, Dict], destination: Union[dict, Dict]
|
70 |
+
) -> Union[dict, Dict]:
|
71 |
+
"""
|
72 |
+
run me with nosetests --with-doctest file.py
|
73 |
+
>>> a = { 'first' : { 'all_rows' : { 'pass' : 'dog', 'number' : '1' } } }
|
74 |
+
>>> b = { 'first' : { 'all_rows' : { 'fail' : 'cat', 'number' : '5' } } }
|
75 |
+
>>> merge(b, a) == {
|
76 |
+
'first' : {
|
77 |
+
'all_rows' : { '
|
78 |
+
pass' : 'dog',
|
79 |
+
'fail' : 'cat',
|
80 |
+
'number' : '5'
|
81 |
+
}
|
82 |
+
}
|
83 |
+
}
|
84 |
+
True
|
85 |
+
"""
|
86 |
+
for key, value in source.items():
|
87 |
+
try:
|
88 |
+
if isinstance(value, dict):
|
89 |
+
# get node or create one
|
90 |
+
node = destination.setdefault(key, {})
|
91 |
+
merge(value, node)
|
92 |
+
else:
|
93 |
+
if isinstance(destination, dict):
|
94 |
+
destination[key] = value
|
95 |
+
else:
|
96 |
+
destination = {key: value}
|
97 |
+
except TypeError as e:
|
98 |
+
print(traceback.format_exc())
|
99 |
+
print(">>>", source)
|
100 |
+
print(">>>", destination)
|
101 |
+
print(">>>", key)
|
102 |
+
print(">>>", value)
|
103 |
+
raise Exception(e)
|
104 |
+
|
105 |
+
return destination
|
106 |
+
|
107 |
+
|
108 |
+
def load_opts(
|
109 |
+
path: Optional[Union[str, Path]] = None,
|
110 |
+
default: Optional[Union[str, Path, dict, Dict]] = None,
|
111 |
+
commandline_opts: Optional[Union[Dict, dict]] = None,
|
112 |
+
) -> Dict:
|
113 |
+
"""Loadsize a configuration Dict from 2 files:
|
114 |
+
1. default files with shared values across runs and users
|
115 |
+
2. an overriding file with run- and user-specific values
|
116 |
+
|
117 |
+
Args:
|
118 |
+
path (pathlib.Path): where to find the overriding configuration
|
119 |
+
default (pathlib.Path, optional): Where to find the default opts.
|
120 |
+
Defaults to None. In which case it is assumed to be a default config
|
121 |
+
which needs processing such as setting default values for lambdas and gen
|
122 |
+
fields
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
addict.Dict: options dictionnary, with overwritten default values
|
126 |
+
"""
|
127 |
+
|
128 |
+
if path is None and default is None:
|
129 |
+
path = (
|
130 |
+
resolve(Path(__file__)).parent.parent
|
131 |
+
/ "shared"
|
132 |
+
/ "trainer"
|
133 |
+
/ "defaults.yaml"
|
134 |
+
)
|
135 |
+
|
136 |
+
if path:
|
137 |
+
path = resolve(path)
|
138 |
+
|
139 |
+
if default is None:
|
140 |
+
default_opts = {}
|
141 |
+
else:
|
142 |
+
if isinstance(default, (str, Path)):
|
143 |
+
with open(default, "r") as f:
|
144 |
+
default_opts = yaml.safe_load(f)
|
145 |
+
else:
|
146 |
+
default_opts = dict(default)
|
147 |
+
|
148 |
+
if path is None:
|
149 |
+
overriding_opts = {}
|
150 |
+
else:
|
151 |
+
with open(path, "r") as f:
|
152 |
+
overriding_opts = yaml.safe_load(f) or {}
|
153 |
+
|
154 |
+
opts = Dict(merge(overriding_opts, default_opts))
|
155 |
+
|
156 |
+
if commandline_opts is not None and isinstance(commandline_opts, dict):
|
157 |
+
opts = Dict(merge(commandline_opts, opts))
|
158 |
+
|
159 |
+
if opts.train.kitti.pretrained:
|
160 |
+
assert "kitti" in opts.data.files.train
|
161 |
+
assert "kitti" in opts.data.files.val
|
162 |
+
assert opts.train.kitti.epochs > 0
|
163 |
+
|
164 |
+
opts.domains = []
|
165 |
+
if "m" in opts.tasks or "s" in opts.tasks or "d" in opts.tasks:
|
166 |
+
opts.domains.extend(["r", "s"])
|
167 |
+
if "p" in opts.tasks:
|
168 |
+
opts.domains.append("rf")
|
169 |
+
if opts.train.kitti.pretrain:
|
170 |
+
opts.domains.append("kitti")
|
171 |
+
|
172 |
+
opts.domains = list(set(opts.domains))
|
173 |
+
|
174 |
+
if "s" in opts.tasks:
|
175 |
+
if opts.gen.encoder.architecture != opts.gen.s.architecture:
|
176 |
+
print(
|
177 |
+
"WARNING: segmentation encoder and decoder architectures do not match"
|
178 |
+
)
|
179 |
+
print(
|
180 |
+
"Encoder: {} <> Decoder: {}".format(
|
181 |
+
opts.gen.encoder.architecture, opts.gen.s.architecture
|
182 |
+
)
|
183 |
+
)
|
184 |
+
if opts.gen.m.use_spade:
|
185 |
+
if "d" not in opts.tasks or "s" not in opts.tasks:
|
186 |
+
raise ValueError(
|
187 |
+
"opts.gen.m.use_spade is True so tasks MUST include"
|
188 |
+
+ "both d and s, but received {}".format(opts.tasks)
|
189 |
+
)
|
190 |
+
if opts.gen.d.classify.enable:
|
191 |
+
raise ValueError(
|
192 |
+
"opts.gen.m.use_spade is True but using D as a classifier"
|
193 |
+
+ " which is a non-implemented combination"
|
194 |
+
)
|
195 |
+
|
196 |
+
if opts.gen.s.depth_feat_fusion is True or opts.gen.s.depth_dada_fusion is True:
|
197 |
+
opts.gen.s.use_dada = True
|
198 |
+
|
199 |
+
events_path = (
|
200 |
+
resolve(Path(__file__)).parent.parent / "shared" / "trainer" / "events.yaml"
|
201 |
+
)
|
202 |
+
if events_path.exists():
|
203 |
+
with events_path.open("r") as f:
|
204 |
+
events_dict = yaml.safe_load(f)
|
205 |
+
events_dict = Dict(events_dict)
|
206 |
+
opts.events = events_dict
|
207 |
+
|
208 |
+
return set_data_paths(opts)
|
209 |
+
|
210 |
+
|
211 |
+
def set_data_paths(opts: Dict) -> Dict:
|
212 |
+
"""Update the data files paths in data.files.train and data.files.val
|
213 |
+
from data.files.base
|
214 |
+
|
215 |
+
Args:
|
216 |
+
opts (addict.Dict): options
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
addict.Dict: updated options
|
220 |
+
"""
|
221 |
+
|
222 |
+
for mode in ["train", "val"]:
|
223 |
+
for domain in opts.data.files[mode]:
|
224 |
+
if opts.data.files.base and not opts.data.files[mode][domain].startswith(
|
225 |
+
"/"
|
226 |
+
):
|
227 |
+
opts.data.files[mode][domain] = str(
|
228 |
+
Path(opts.data.files.base) / opts.data.files[mode][domain]
|
229 |
+
)
|
230 |
+
assert Path(
|
231 |
+
opts.data.files[mode][domain]
|
232 |
+
).exists(), "Cannot find {}".format(str(opts.data.files[mode][domain]))
|
233 |
+
|
234 |
+
return opts
|
235 |
+
|
236 |
+
|
237 |
+
def load_test_opts(test_file_path: str = "config/trainer/local_tests.yaml") -> Dict:
|
238 |
+
"""Returns the special opts set up for local tests
|
239 |
+
Args:
|
240 |
+
test_file_path (str, optional): Name of the file located in config/
|
241 |
+
Defaults to "local_tests.yaml".
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
addict.Dict: Opts loaded from defaults.yaml and updated from test_file_path
|
245 |
+
"""
|
246 |
+
return load_opts(
|
247 |
+
Path(__file__).parent.parent / f"{test_file_path}",
|
248 |
+
default=Path(__file__).parent.parent / "shared/trainer/defaults.yaml",
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
def get_git_revision_hash() -> str:
|
253 |
+
"""Get current git hash the code is run from
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
str: git hash
|
257 |
+
"""
|
258 |
+
try:
|
259 |
+
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
|
260 |
+
except Exception as e:
|
261 |
+
return str(e)
|
262 |
+
|
263 |
+
|
264 |
+
def get_git_branch() -> str:
|
265 |
+
"""Get current git branch name
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
str: git branch name
|
269 |
+
"""
|
270 |
+
try:
|
271 |
+
return (
|
272 |
+
subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
273 |
+
.decode()
|
274 |
+
.strip()
|
275 |
+
)
|
276 |
+
except Exception as e:
|
277 |
+
return str(e)
|
278 |
+
|
279 |
+
|
280 |
+
def kill_job(id: Union[int, str]) -> None:
|
281 |
+
subprocess.check_output(["scancel", str(id)])
|
282 |
+
|
283 |
+
|
284 |
+
def write_hash(path: Union[str, Path]) -> None:
|
285 |
+
hash_code = get_git_revision_hash()
|
286 |
+
with open(path, "w") as f:
|
287 |
+
f.write(hash_code)
|
288 |
+
|
289 |
+
|
290 |
+
def shortuid():
|
291 |
+
return str(uuid4()).split("-")[0]
|
292 |
+
|
293 |
+
|
294 |
+
def datenowshort():
|
295 |
+
"""
|
296 |
+
>>> a = str(datetime.datetime.now())
|
297 |
+
>>> print(a)
|
298 |
+
'2021-02-25 11:34:50.188072'
|
299 |
+
>>> print(a[5:].split(".")[0].replace(" ", "_"))
|
300 |
+
'02-25_11:35:41'
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
str: month-day_h:m:s
|
304 |
+
"""
|
305 |
+
return str(datetime.datetime.now())[5:].split(".")[0].replace(" ", "_")
|
306 |
+
|
307 |
+
|
308 |
+
def get_increased_path(path: Union[str, Path], use_date: bool = False) -> Path:
|
309 |
+
"""Returns an increased path: if dir exists, returns `dir (1)`.
|
310 |
+
If `dir (i)` exists, returns `dir (max(i) + 1)`
|
311 |
+
|
312 |
+
get_increased_path("test").mkdir() creates `test/`
|
313 |
+
then
|
314 |
+
get_increased_path("test").mkdir() creates `test (1)/`
|
315 |
+
etc.
|
316 |
+
if `test (3)/` exists but not `test (2)/`, `test (4)/` is created so that indexes
|
317 |
+
always increase
|
318 |
+
|
319 |
+
Args:
|
320 |
+
path (str or pathlib.Path): the file/directory which may already exist and would
|
321 |
+
need to be increased
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
pathlib.Path: increased path
|
325 |
+
"""
|
326 |
+
fp = resolve(path)
|
327 |
+
if not fp.exists():
|
328 |
+
return fp
|
329 |
+
|
330 |
+
if fp.is_file():
|
331 |
+
if not use_date:
|
332 |
+
while fp.exists():
|
333 |
+
fp = fp.parent / f"{fp.stem}--{shortuid()}{fp.suffix}"
|
334 |
+
return fp
|
335 |
+
else:
|
336 |
+
while fp.exists():
|
337 |
+
time.sleep(0.5)
|
338 |
+
fp = fp.parent / f"{fp.stem}--{datenowshort()}{fp.suffix}"
|
339 |
+
return fp
|
340 |
+
|
341 |
+
if not use_date:
|
342 |
+
while fp.exists():
|
343 |
+
fp = fp.parent / f"{fp.name}--{shortuid()}"
|
344 |
+
return fp
|
345 |
+
else:
|
346 |
+
while fp.exists():
|
347 |
+
time.sleep(0.5)
|
348 |
+
fp = fp.parent / f"{fp.name}--{datenowshort()}"
|
349 |
+
return fp
|
350 |
+
|
351 |
+
# vals = []
|
352 |
+
# for n in fp.parent.glob("{}*".format(fp.stem)):
|
353 |
+
# if re.match(r".+\(\d+\)", str(n.name)) is not None:
|
354 |
+
# name = str(n.name)
|
355 |
+
# start = name.index("(")
|
356 |
+
# end = name.index(")")
|
357 |
+
# vals.append(int(name[start + 1 : end]))
|
358 |
+
# if vals:
|
359 |
+
# ext = " ({})".format(max(vals) + 1)
|
360 |
+
# elif fp.exists():
|
361 |
+
# ext = " (1)"
|
362 |
+
# else:
|
363 |
+
# ext = ""
|
364 |
+
# return fp.parent / (fp.stem + ext + fp.suffix)
|
365 |
+
|
366 |
+
|
367 |
+
def env_to_path(path: str) -> str:
|
368 |
+
"""Transorms an environment variable mention in a json
|
369 |
+
into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds
|
370 |
+
|
371 |
+
Args:
|
372 |
+
path (str): path potentially containing the env variable
|
373 |
+
|
374 |
+
"""
|
375 |
+
path_elements = path.split("/")
|
376 |
+
new_path = []
|
377 |
+
for el in path_elements:
|
378 |
+
if "$" in el:
|
379 |
+
new_path.append(os.environ[el.replace("$", "")])
|
380 |
+
else:
|
381 |
+
new_path.append(el)
|
382 |
+
return "/".join(new_path)
|
383 |
+
|
384 |
+
|
385 |
+
def flatten_opts(opts: Dict) -> dict:
|
386 |
+
"""Flattens a multi-level addict.Dict or native dictionnary into a single
|
387 |
+
level native dict with string keys representing the keys sequence to reach
|
388 |
+
a value in the original argument.
|
389 |
+
|
390 |
+
d = addict.Dict()
|
391 |
+
d.a.b.c = 2
|
392 |
+
d.a.b.d = 3
|
393 |
+
d.a.e = 4
|
394 |
+
d.f = 5
|
395 |
+
flatten_opts(d)
|
396 |
+
>>> {
|
397 |
+
"a.b.c": 2,
|
398 |
+
"a.b.d": 3,
|
399 |
+
"a.e": 4,
|
400 |
+
"f": 5,
|
401 |
+
}
|
402 |
+
|
403 |
+
Args:
|
404 |
+
opts (addict.Dict or dict): addict dictionnary to flatten
|
405 |
+
|
406 |
+
Returns:
|
407 |
+
dict: flattened dictionnary
|
408 |
+
"""
|
409 |
+
values_list = []
|
410 |
+
|
411 |
+
def p(d, prefix="", vals=[]):
|
412 |
+
for k, v in d.items():
|
413 |
+
if isinstance(v, (Dict, dict)):
|
414 |
+
p(v, prefix + k + ".", vals)
|
415 |
+
elif isinstance(v, list):
|
416 |
+
if v and isinstance(v[0], (Dict, dict)):
|
417 |
+
for i, m in enumerate(v):
|
418 |
+
p(m, prefix + k + "." + str(i) + ".", vals)
|
419 |
+
else:
|
420 |
+
vals.append((prefix + k, str(v)))
|
421 |
+
else:
|
422 |
+
if isinstance(v, Path):
|
423 |
+
v = str(v)
|
424 |
+
vals.append((prefix + k, v))
|
425 |
+
|
426 |
+
p(opts, vals=values_list)
|
427 |
+
return dict(values_list)
|
428 |
+
|
429 |
+
|
430 |
+
def get_comet_rest_api_key(
|
431 |
+
path_to_config_file: Optional[Union[str, Path]] = None
|
432 |
+
) -> str:
|
433 |
+
"""Gets a comet.ml rest_api_key in the following order:
|
434 |
+
* config file specified as argument
|
435 |
+
* environment variable
|
436 |
+
* .comet.config file in the current working diretory
|
437 |
+
* .comet.config file in your home
|
438 |
+
|
439 |
+
config files must have a line like `rest_api_key=<some api key>`
|
440 |
+
|
441 |
+
Args:
|
442 |
+
path_to_config_file (str or pathlib.Path, optional): config_file to use.
|
443 |
+
Defaults to None.
|
444 |
+
|
445 |
+
Raises:
|
446 |
+
ValueError: can't find a file
|
447 |
+
ValueError: can't find the key in a file
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
str: your comet rest_api_key
|
451 |
+
"""
|
452 |
+
if "COMET_REST_API_KEY" in os.environ and path_to_config_file is None:
|
453 |
+
return os.environ["COMET_REST_API_KEY"]
|
454 |
+
if path_to_config_file is not None:
|
455 |
+
p = resolve(path_to_config_file)
|
456 |
+
else:
|
457 |
+
p = Path() / ".comet.config"
|
458 |
+
if not p.exists():
|
459 |
+
p = Path.home() / ".comet.config"
|
460 |
+
if not p.exists():
|
461 |
+
raise ValueError("Unable to find your COMET_REST_API_KEY")
|
462 |
+
with p.open("r") as f:
|
463 |
+
for keys in f:
|
464 |
+
if "rest_api_key" in keys:
|
465 |
+
return keys.strip().split("=")[-1].strip()
|
466 |
+
raise ValueError("Unable to find your COMET_REST_API_KEY in {}".format(str(p)))
|
467 |
+
|
468 |
+
|
469 |
+
def get_files(dirName: str) -> list:
|
470 |
+
# create a list of file and sub directories
|
471 |
+
files = sorted(os.listdir(dirName))
|
472 |
+
all_files = list()
|
473 |
+
for entry in files:
|
474 |
+
fullPath = os.path.join(dirName, entry)
|
475 |
+
if os.path.isdir(fullPath):
|
476 |
+
all_files = all_files + get_files(fullPath)
|
477 |
+
else:
|
478 |
+
all_files.append(fullPath)
|
479 |
+
|
480 |
+
return all_files
|
481 |
+
|
482 |
+
|
483 |
+
def make_json_file(
|
484 |
+
tasks: List[str],
|
485 |
+
addresses: List[str], # for windows user, use "\\" instead of using "/"
|
486 |
+
json_names: List[str] = ["train_jsonfile.json", "val_jsonfile.json"],
|
487 |
+
splitter: str = "/",
|
488 |
+
pourcentage_val: float = 0.15,
|
489 |
+
) -> None:
|
490 |
+
"""
|
491 |
+
How to use it?
|
492 |
+
e.g.
|
493 |
+
make_json_file(['x','m','d'], [
|
494 |
+
'/network/tmp1/ccai/data/munit_dataset/trainA_size_1200/',
|
495 |
+
'/network/tmp1/ccai/data/munit_dataset/seg_trainA_size_1200/',
|
496 |
+
'/network/tmp1/ccai/data/munit_dataset/trainA_megadepth_resized/'
|
497 |
+
], ["train_r.json", "val_r.json"])
|
498 |
+
|
499 |
+
Args:
|
500 |
+
tasks (list): the list of image type like 'x', 'm', 'd', etc.
|
501 |
+
addresses (list): the list of the corresponding address of the
|
502 |
+
image type mentioned in tasks
|
503 |
+
json_names (list): names for the json files, train being first
|
504 |
+
(e.g. : ["train_r.json", "val_r.json"])
|
505 |
+
splitter (str, optional): The path separator for the current OS.
|
506 |
+
Defaults to '/'.
|
507 |
+
pourcentage_val: pourcentage of files to go in validation set
|
508 |
+
"""
|
509 |
+
assert len(tasks) == len(addresses), "keys and addresses must have the same length!"
|
510 |
+
|
511 |
+
files = [get_files(addresses[j]) for j in range(len(tasks))]
|
512 |
+
n_files_val = int(pourcentage_val * len(files[0]))
|
513 |
+
n_files_train = len(files[0]) - n_files_val
|
514 |
+
filenames = [files[0][:n_files_train], files[0][-n_files_val:]]
|
515 |
+
|
516 |
+
file_address_map = {
|
517 |
+
tasks[j]: {
|
518 |
+
".".join(file.split(splitter)[-1].split(".")[:-1]): file
|
519 |
+
for file in files[j]
|
520 |
+
}
|
521 |
+
for j in range(len(tasks))
|
522 |
+
}
|
523 |
+
# The tasks of the file_address_map are like 'x', 'm', 'd'...
|
524 |
+
# The values of the file_address_map are a dictionary whose tasks are the
|
525 |
+
# filenames without extension whose values are the path of the filename
|
526 |
+
# e.g. file_address_map =
|
527 |
+
# {'x': {'A': 'path/to/trainA_size_1200/A.png', ...},
|
528 |
+
# 'm': {'A': 'path/to/seg_trainA_size_1200/A.jpg',...}
|
529 |
+
# 'd': {'A': 'path/to/trainA_megadepth_resized/A.bmp',...}
|
530 |
+
# ...}
|
531 |
+
|
532 |
+
for i, json_name in enumerate(json_names):
|
533 |
+
dicts = []
|
534 |
+
for j in range(len(filenames[i])):
|
535 |
+
file = filenames[i][j]
|
536 |
+
filename = file.split(splitter)[-1] # the filename with 'x' extension
|
537 |
+
filename_ = ".".join(
|
538 |
+
filename.split(".")[:-1]
|
539 |
+
) # the filename without extension
|
540 |
+
tmp_dict = {}
|
541 |
+
for k in range(len(tasks)):
|
542 |
+
tmp_dict[tasks[k]] = file_address_map[tasks[k]][filename_]
|
543 |
+
dicts.append(tmp_dict)
|
544 |
+
with open(json_name, "w", encoding="utf-8") as outfile:
|
545 |
+
json.dump(dicts, outfile, ensure_ascii=False)
|
546 |
+
|
547 |
+
|
548 |
+
def append_task_to_json(
|
549 |
+
path_to_json: Union[str, Path],
|
550 |
+
path_to_new_json: Union[str, Path],
|
551 |
+
path_to_new_images_dir: Union[str, Path],
|
552 |
+
new_task_name: str,
|
553 |
+
):
|
554 |
+
"""Add all files for a task to an existing json file by creating a new json file
|
555 |
+
in the specified path.
|
556 |
+
Assumes that the files for the new task have exactly the same names as the ones
|
557 |
+
for the other tasks
|
558 |
+
|
559 |
+
Args:
|
560 |
+
path_to_json: complete path to the json file to modify
|
561 |
+
path_to_new_json: complete path to the new json file to be created
|
562 |
+
path_to_new_images_dir: complete path of the directory where to find the
|
563 |
+
images for the new task
|
564 |
+
new_task_name: name of the new task
|
565 |
+
|
566 |
+
e.g:
|
567 |
+
append_json(
|
568 |
+
"/network/tmp1/ccai/data/climategan/seg/train_r.json",
|
569 |
+
"/network/tmp1/ccai/data/climategan/seg/train_r_new.json"
|
570 |
+
"/network/tmp1/ccai/data/munit_dataset/trainA_seg_HRNet/unity_labels",
|
571 |
+
"s",
|
572 |
+
)
|
573 |
+
"""
|
574 |
+
ims_list = None
|
575 |
+
if path_to_json:
|
576 |
+
path_to_json = Path(path_to_json).resolve()
|
577 |
+
with open(path_to_json, "r") as f:
|
578 |
+
ims_list = json.load(f)
|
579 |
+
|
580 |
+
files = get_files(path_to_new_images_dir)
|
581 |
+
|
582 |
+
if ims_list is None:
|
583 |
+
raise ValueError(f"Could not find the list in {path_to_json}")
|
584 |
+
|
585 |
+
new_ims_list = [None] * len(ims_list)
|
586 |
+
for i, im_dict in enumerate(ims_list):
|
587 |
+
new_ims_list[i] = {}
|
588 |
+
for task, path in im_dict.items():
|
589 |
+
new_ims_list[i][task] = path
|
590 |
+
|
591 |
+
for i, im_dict in enumerate(ims_list):
|
592 |
+
for task, path in im_dict.items():
|
593 |
+
file_name = os.path.splitext(path)[0] # removes extension
|
594 |
+
file_name = file_name.rsplit("/", 1)[-1] # only the file_name
|
595 |
+
file_found = False
|
596 |
+
for file_path in files:
|
597 |
+
if file_name in file_path:
|
598 |
+
file_found = True
|
599 |
+
new_ims_list[i][new_task_name] = file_path
|
600 |
+
break
|
601 |
+
if file_found:
|
602 |
+
break
|
603 |
+
else:
|
604 |
+
print("Error! File ", file_name, "not found in directory!")
|
605 |
+
return
|
606 |
+
|
607 |
+
with open(path_to_new_json, "w", encoding="utf-8") as f:
|
608 |
+
json.dump(new_ims_list, f, ensure_ascii=False)
|
609 |
+
|
610 |
+
|
611 |
+
def sum_dict(dict1: Union[dict, Dict], dict2: Union[Dict, dict]) -> Union[dict, Dict]:
|
612 |
+
"""Add dict2 into dict1"""
|
613 |
+
for k, v in dict2.items():
|
614 |
+
if not isinstance(v, dict):
|
615 |
+
dict1[k] += v
|
616 |
+
else:
|
617 |
+
sum_dict(dict1[k], dict2[k])
|
618 |
+
return dict1
|
619 |
+
|
620 |
+
|
621 |
+
def div_dict(dict1: Union[dict, Dict], div_by: float) -> dict:
|
622 |
+
"""Divide elements of dict1 by div_by"""
|
623 |
+
for k, v in dict1.items():
|
624 |
+
if not isinstance(v, dict):
|
625 |
+
dict1[k] /= div_by
|
626 |
+
else:
|
627 |
+
div_dict(dict1[k], div_by)
|
628 |
+
return dict1
|
629 |
+
|
630 |
+
|
631 |
+
def comet_id_from_url(url: str) -> Optional[str]:
|
632 |
+
"""
|
633 |
+
Get comet exp id from its url:
|
634 |
+
https://www.comet.ml/vict0rsch/climategan/2a1a4a96afe848218c58ac4e47c5375f
|
635 |
+
-> 2a1a4a96afe848218c58ac4e47c5375f
|
636 |
+
|
637 |
+
Args:
|
638 |
+
url (str): comet exp url
|
639 |
+
|
640 |
+
Returns:
|
641 |
+
str: comet exp id
|
642 |
+
"""
|
643 |
+
try:
|
644 |
+
ids = url.split("/")
|
645 |
+
ids = [i for i in ids if i]
|
646 |
+
return ids[-1]
|
647 |
+
except Exception:
|
648 |
+
return None
|
649 |
+
|
650 |
+
|
651 |
+
@contextlib.contextmanager
|
652 |
+
def temp_np_seed(seed: Optional[int]) -> None:
|
653 |
+
"""
|
654 |
+
Set temporary numpy seed:
|
655 |
+
with temp_np_seed(123):
|
656 |
+
np.random.permutation(3)
|
657 |
+
|
658 |
+
Args:
|
659 |
+
seed (int): temporary numpy seed
|
660 |
+
"""
|
661 |
+
state = np.random.get_state()
|
662 |
+
np.random.seed(seed)
|
663 |
+
try:
|
664 |
+
yield
|
665 |
+
finally:
|
666 |
+
np.random.set_state(state)
|
667 |
+
|
668 |
+
|
669 |
+
def get_display_indices(opts: Dict, domain: str, length: int) -> list:
|
670 |
+
"""
|
671 |
+
Compute the index of images to use for comet logging:
|
672 |
+
if opts.comet.display_indices is an int, and domain is real:
|
673 |
+
return range(int)
|
674 |
+
if opts.comet.display_indices is an int, and domain is sim:
|
675 |
+
return permutation(length)[:int]
|
676 |
+
if opts.comet.display_indices is a list:
|
677 |
+
return list
|
678 |
+
|
679 |
+
otherwise return []
|
680 |
+
|
681 |
+
|
682 |
+
Args:
|
683 |
+
opts (addict.Dict): options
|
684 |
+
domain (str): domain for those indices
|
685 |
+
length (int): length of dataset for the permutation
|
686 |
+
|
687 |
+
Returns:
|
688 |
+
list(int): The indices to display
|
689 |
+
"""
|
690 |
+
if domain == "rf":
|
691 |
+
dsize = max([opts.comet.display_size, opts.train.fid.get("n_images", 0)])
|
692 |
+
else:
|
693 |
+
dsize = opts.comet.display_size
|
694 |
+
if dsize > length:
|
695 |
+
print(
|
696 |
+
f"Warning: dataset is smaller ({length} images) "
|
697 |
+
+ f"than required display indices ({dsize})."
|
698 |
+
+ f" Selecting {length} images."
|
699 |
+
)
|
700 |
+
|
701 |
+
display_indices = []
|
702 |
+
assert isinstance(dsize, (int, list)), "Unknown display size {}".format(dsize)
|
703 |
+
if isinstance(dsize, int):
|
704 |
+
assert dsize >= 0, "Display size cannot be < 0"
|
705 |
+
with temp_np_seed(123):
|
706 |
+
display_indices = list(np.random.permutation(length)[:dsize])
|
707 |
+
elif isinstance(dsize, list):
|
708 |
+
display_indices = dsize
|
709 |
+
|
710 |
+
if not display_indices:
|
711 |
+
print("Warning: no display indices (utils.get_display_indices)")
|
712 |
+
|
713 |
+
return display_indices
|
714 |
+
|
715 |
+
|
716 |
+
def get_latest_path(path: Union[str, Path]) -> Path:
|
717 |
+
"""
|
718 |
+
Get the file/dir with largest increment i as `file (i).ext`
|
719 |
+
|
720 |
+
Args:
|
721 |
+
path (str or pathlib.Path): base pattern
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
Path: path found
|
725 |
+
"""
|
726 |
+
p = Path(path).resolve()
|
727 |
+
s = p.stem
|
728 |
+
e = p.suffix
|
729 |
+
files = list(p.parent.glob(f"{s}*(*){e}"))
|
730 |
+
indices = list(p.parent.glob(f"{s}*(*){e}"))
|
731 |
+
indices = list(map(lambda f: f.name, indices))
|
732 |
+
indices = list(map(lambda x: re.findall(r"\((.*?)\)", x)[-1], indices))
|
733 |
+
indices = list(map(int, indices))
|
734 |
+
if not indices:
|
735 |
+
f = p
|
736 |
+
else:
|
737 |
+
f = files[np.argmax(indices)]
|
738 |
+
return f
|
739 |
+
|
740 |
+
|
741 |
+
def get_existing_jobID(output_path: Path) -> str:
|
742 |
+
"""
|
743 |
+
If the opts in output_path have a jobID, return it. Else, return None
|
744 |
+
|
745 |
+
Args:
|
746 |
+
output_path (pathlib.Path | str): where to look
|
747 |
+
|
748 |
+
Returns:
|
749 |
+
str | None: jobid
|
750 |
+
"""
|
751 |
+
op = Path(output_path)
|
752 |
+
if not op.exists():
|
753 |
+
return
|
754 |
+
|
755 |
+
opts_path = get_latest_path(op / "opts.yaml")
|
756 |
+
|
757 |
+
if not opts_path.exists():
|
758 |
+
return
|
759 |
+
|
760 |
+
with opts_path.open("r") as f:
|
761 |
+
opts = yaml.safe_load(f)
|
762 |
+
|
763 |
+
jobID = opts.get("jobID", None)
|
764 |
+
|
765 |
+
return jobID
|
766 |
+
|
767 |
+
|
768 |
+
def find_existing_training(opts: Dict) -> Optional[Path]:
|
769 |
+
"""
|
770 |
+
Looks in all directories like output_path.parent.glob(output_path.name*)
|
771 |
+
and compares the logged slurm job id with the current opts.jobID
|
772 |
+
|
773 |
+
If a match is found, the training should automatically continue in the
|
774 |
+
matching output directory
|
775 |
+
|
776 |
+
If no match is found, this is a new job and it should have a new output path
|
777 |
+
|
778 |
+
Args:
|
779 |
+
opts (Dict): trainer's options
|
780 |
+
|
781 |
+
Returns:
|
782 |
+
Optional[Path]: a path if a matchin jobID is found, None otherwise
|
783 |
+
"""
|
784 |
+
if opts.jobID is None:
|
785 |
+
print("WARNING: current JOBID is None")
|
786 |
+
return
|
787 |
+
|
788 |
+
print("---------- Current job id:", opts.jobID)
|
789 |
+
|
790 |
+
path = Path(opts.output_path).resolve()
|
791 |
+
parent = path.parent
|
792 |
+
name = path.name
|
793 |
+
|
794 |
+
try:
|
795 |
+
similar_dirs = [p.resolve() for p in parent.glob(f"{name}*") if p.is_dir()]
|
796 |
+
|
797 |
+
for sd in similar_dirs:
|
798 |
+
candidate_jobID = get_existing_jobID(sd)
|
799 |
+
if candidate_jobID is not None and str(opts.jobID) == str(candidate_jobID):
|
800 |
+
print(f"Found matching job id in {sd}\n")
|
801 |
+
return sd
|
802 |
+
print("Did not find a matching job id in \n {}\n".format(str(similar_dirs)))
|
803 |
+
except Exception as e:
|
804 |
+
print("ERROR: Could not resume (find_existing_training)", e)
|
805 |
+
|
806 |
+
|
807 |
+
def pprint(*args: List[Any]):
|
808 |
+
"""
|
809 |
+
Prints *args within a box of "=" characters
|
810 |
+
"""
|
811 |
+
txt = " ".join(map(str, args))
|
812 |
+
col = "====="
|
813 |
+
space = " "
|
814 |
+
head_size = 2
|
815 |
+
header = "\n".join(["=" * (len(txt) + 2 * (len(col) + len(space)))] * head_size)
|
816 |
+
empty = "{}{}{}{}{}".format(col, space, " " * (len(txt)), space, col)
|
817 |
+
print()
|
818 |
+
print(header)
|
819 |
+
print(empty)
|
820 |
+
print("{}{}{}{}{}".format(col, space, txt, space, col))
|
821 |
+
print(empty)
|
822 |
+
print(header)
|
823 |
+
print()
|
824 |
+
|
825 |
+
|
826 |
+
def get_existing_comet_id(path: str) -> Optional[str]:
|
827 |
+
"""
|
828 |
+
Returns the id of the existing comet experiment stored in path
|
829 |
+
|
830 |
+
Args:
|
831 |
+
path (str): Output pat where to look for the comet exp
|
832 |
+
|
833 |
+
Returns:
|
834 |
+
Optional[str]: comet exp's ID if any was found
|
835 |
+
"""
|
836 |
+
comet_previous_path = get_latest_path(Path(path) / "comet_url.txt")
|
837 |
+
if comet_previous_path.exists():
|
838 |
+
with comet_previous_path.open("r") as f:
|
839 |
+
url = f.read().strip()
|
840 |
+
return comet_id_from_url(url)
|
841 |
+
|
842 |
+
|
843 |
+
def get_latest_opts(path):
|
844 |
+
"""
|
845 |
+
get latest opts dumped in path if they look like *opts*.yaml
|
846 |
+
and were increased as
|
847 |
+
opts.yaml < opts (1).yaml < opts (2).yaml etc.
|
848 |
+
|
849 |
+
Args:
|
850 |
+
path (str or pathlib.Path): where to look for opts
|
851 |
+
|
852 |
+
Raises:
|
853 |
+
ValueError: If no match for *opts*.yaml is found
|
854 |
+
|
855 |
+
Returns:
|
856 |
+
addict.Dict: loaded opts
|
857 |
+
"""
|
858 |
+
path = Path(path)
|
859 |
+
opts = get_latest_path(path / "opts.yaml")
|
860 |
+
assert opts.exists()
|
861 |
+
with opts.open("r") as f:
|
862 |
+
opts = Dict(yaml.safe_load(f))
|
863 |
+
|
864 |
+
events_path = Path(__file__).parent.parent / "shared" / "trainer" / "events.yaml"
|
865 |
+
if events_path.exists():
|
866 |
+
with events_path.open("r") as f:
|
867 |
+
events_dict = yaml.safe_load(f)
|
868 |
+
events_dict = Dict(events_dict)
|
869 |
+
opts.events = events_dict
|
870 |
+
|
871 |
+
return opts
|
872 |
+
|
873 |
+
|
874 |
+
def text_to_array(text, width=640, height=40):
|
875 |
+
"""
|
876 |
+
Creates a numpy array of shape height x width x 3 with
|
877 |
+
text written on it using PIL
|
878 |
+
|
879 |
+
Args:
|
880 |
+
text (str): text to write
|
881 |
+
width (int, optional): Width of the resulting array. Defaults to 640.
|
882 |
+
height (int, optional): Height of the resulting array. Defaults to 40.
|
883 |
+
|
884 |
+
Returns:
|
885 |
+
np.ndarray: Centered text
|
886 |
+
"""
|
887 |
+
from PIL import Image, ImageDraw, ImageFont
|
888 |
+
|
889 |
+
img = Image.new("RGB", (width, height), (255, 255, 255))
|
890 |
+
try:
|
891 |
+
font = ImageFont.truetype("UnBatang.ttf", 25)
|
892 |
+
except OSError:
|
893 |
+
font = ImageFont.load_default()
|
894 |
+
|
895 |
+
d = ImageDraw.Draw(img)
|
896 |
+
text_width, text_height = d.textsize(text)
|
897 |
+
h = 40 // 2 - 3 * text_height // 2
|
898 |
+
w = width // 2 - text_width
|
899 |
+
d.text((w, h), text, font=font, fill=(30, 30, 30))
|
900 |
+
return np.array(img)
|
901 |
+
|
902 |
+
|
903 |
+
def all_texts_to_array(texts, width=640, height=40):
|
904 |
+
"""
|
905 |
+
Creates an array of texts, each of height and width specified
|
906 |
+
by the args, concatenated along their width dimension
|
907 |
+
|
908 |
+
Args:
|
909 |
+
texts (list(str)): List of texts to concatenate
|
910 |
+
width (int, optional): Individual text's width. Defaults to 640.
|
911 |
+
height (int, optional): Individual text's height. Defaults to 40.
|
912 |
+
|
913 |
+
Returns:
|
914 |
+
list: len(texts) text arrays with dims height x width x 3
|
915 |
+
"""
|
916 |
+
return [text_to_array(text, width, height) for text in texts]
|
917 |
+
|
918 |
+
|
919 |
+
class Timer:
|
920 |
+
def __init__(self, name="", store=None, precision=3, ignore=False, cuda=True):
|
921 |
+
self.name = name
|
922 |
+
self.store = store
|
923 |
+
self.precision = precision
|
924 |
+
self.ignore = ignore
|
925 |
+
self.cuda = cuda
|
926 |
+
|
927 |
+
if cuda:
|
928 |
+
self._start_event = torch.cuda.Event(enable_timing=True)
|
929 |
+
self._end_event = torch.cuda.Event(enable_timing=True)
|
930 |
+
|
931 |
+
def format(self, n):
|
932 |
+
return f"{n:.{self.precision}f}"
|
933 |
+
|
934 |
+
def __enter__(self):
|
935 |
+
"""Start a new timer as a context manager"""
|
936 |
+
if self.cuda:
|
937 |
+
self._start_event.record()
|
938 |
+
else:
|
939 |
+
self._start_time = time.perf_counter()
|
940 |
+
return self
|
941 |
+
|
942 |
+
def __exit__(self, *exc_info):
|
943 |
+
"""Stop the context manager timer"""
|
944 |
+
if self.ignore:
|
945 |
+
return
|
946 |
+
|
947 |
+
if self.cuda:
|
948 |
+
self._end_event.record()
|
949 |
+
torch.cuda.synchronize()
|
950 |
+
new_time = self._start_event.elapsed_time(self._end_event) / 1000
|
951 |
+
else:
|
952 |
+
t = time.perf_counter()
|
953 |
+
new_time = t - self._start_time
|
954 |
+
|
955 |
+
if self.store is not None:
|
956 |
+
assert isinstance(self.store, list)
|
957 |
+
self.store.append(new_time)
|
958 |
+
if self.name:
|
959 |
+
print(f"[{self.name}] Elapsed time: {self.format(new_time)}")
|
960 |
+
|
961 |
+
|
962 |
+
def get_loader_output_shape_from_opts(opts):
|
963 |
+
transforms = opts.data.transforms
|
964 |
+
|
965 |
+
t = None
|
966 |
+
for t in transforms[::-1]:
|
967 |
+
if t.name == "resize":
|
968 |
+
break
|
969 |
+
assert t is not None
|
970 |
+
|
971 |
+
if isinstance(t.new_size, Dict):
|
972 |
+
return {
|
973 |
+
task: (
|
974 |
+
t.new_size.get(task, t.new_size.default),
|
975 |
+
t.new_size.get(task, t.new_size.default),
|
976 |
+
)
|
977 |
+
for task in opts.tasks + ["x"]
|
978 |
+
}
|
979 |
+
assert isinstance(t.new_size, int)
|
980 |
+
new_size = (t.new_size, t.new_size)
|
981 |
+
return {task: new_size for task in opts.tasks + ["x"]}
|
982 |
+
|
983 |
+
|
984 |
+
def find_target_size(opts, task):
|
985 |
+
target_size = None
|
986 |
+
if isinstance(opts.data.transforms[-1].new_size, int):
|
987 |
+
target_size = opts.data.transforms[-1].new_size
|
988 |
+
else:
|
989 |
+
if task in opts.data.transforms[-1].new_size:
|
990 |
+
target_size = opts.data.transforms[-1].new_size[task]
|
991 |
+
else:
|
992 |
+
assert "default" in opts.data.transforms[-1].new_size
|
993 |
+
target_size = opts.data.transforms[-1].new_size["default"]
|
994 |
+
|
995 |
+
return target_size
|
996 |
+
|
997 |
+
|
998 |
+
def to_128(im, w_target=-1):
|
999 |
+
h, w = im.shape[:2]
|
1000 |
+
aspect_ratio = h / w
|
1001 |
+
if w_target < 0:
|
1002 |
+
w_target = w
|
1003 |
+
|
1004 |
+
nw = int(w_target / 128) * 128
|
1005 |
+
nh = int(nw * aspect_ratio / 128) * 128
|
1006 |
+
|
1007 |
+
return nh, nw
|
1008 |
+
|
1009 |
+
|
1010 |
+
def is_image_file(filename):
|
1011 |
+
"""Check that a file's name points to a known image format"""
|
1012 |
+
if isinstance(filename, Path):
|
1013 |
+
return filename.suffix in IMG_EXTENSIONS
|
1014 |
+
|
1015 |
+
return Path(filename).suffix in IMG_EXTENSIONS
|
1016 |
+
|
1017 |
+
|
1018 |
+
def find_images(path, recursive=False):
|
1019 |
+
"""
|
1020 |
+
Get a list of all images contained in a directory:
|
1021 |
+
|
1022 |
+
- path.glob("*") if not recursive
|
1023 |
+
- path.glob("**/*") if recursive
|
1024 |
+
"""
|
1025 |
+
p = Path(path)
|
1026 |
+
assert p.exists()
|
1027 |
+
assert p.is_dir()
|
1028 |
+
pattern = "*"
|
1029 |
+
if recursive:
|
1030 |
+
pattern += "*/*"
|
1031 |
+
|
1032 |
+
return [i for i in p.glob(pattern) if i.is_file() and is_image_file(i)]
|
1033 |
+
|
1034 |
+
|
1035 |
+
def cols():
|
1036 |
+
try:
|
1037 |
+
col = os.get_terminal_size().columns
|
1038 |
+
except Exception:
|
1039 |
+
col = 50
|
1040 |
+
return col
|
1041 |
+
|
1042 |
+
|
1043 |
+
def upload_images_to_exp(
|
1044 |
+
path, exp=None, project_name="climategan-eval", sleep=-1, verbose=0
|
1045 |
+
):
|
1046 |
+
ims = find_images(path)
|
1047 |
+
end = None
|
1048 |
+
c = cols()
|
1049 |
+
if verbose == 1:
|
1050 |
+
end = "\r"
|
1051 |
+
if verbose > 1:
|
1052 |
+
end = "\n"
|
1053 |
+
if exp is None:
|
1054 |
+
exp = Experiment(project_name=project_name)
|
1055 |
+
for im in ims:
|
1056 |
+
exp.log_image(str(im))
|
1057 |
+
if verbose > 0:
|
1058 |
+
if verbose == 1:
|
1059 |
+
print(" " * (c - 1), end="\r", flush=True)
|
1060 |
+
print(str(im), end=end, flush=True)
|
1061 |
+
if sleep > 0:
|
1062 |
+
time.sleep(sleep)
|
1063 |
+
return exp
|
eval_masker.py
ADDED
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Compute metrics of the performance of the masker using a set of ground-truth labels
|
3 |
+
|
4 |
+
run eval_masker.py --model "/miniscratch/_groups/ccai/checkpoints/model/"
|
5 |
+
|
6 |
+
"""
|
7 |
+
print("Imports...", end="")
|
8 |
+
import os
|
9 |
+
import os.path
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from comet_ml import Experiment
|
17 |
+
import torch
|
18 |
+
import yaml
|
19 |
+
from skimage.color import rgba2rgb
|
20 |
+
from skimage.io import imread, imsave
|
21 |
+
from skimage.transform import resize
|
22 |
+
from skimage.util import img_as_ubyte
|
23 |
+
from torchvision.transforms import ToTensor
|
24 |
+
|
25 |
+
from climategan.data import encode_mask_label
|
26 |
+
from climategan.eval_metrics import (
|
27 |
+
masker_classification_metrics,
|
28 |
+
get_confusion_matrix,
|
29 |
+
edges_coherence_std_min,
|
30 |
+
boxplot_metric,
|
31 |
+
clustermap_metric,
|
32 |
+
)
|
33 |
+
from climategan.transforms import PrepareTest
|
34 |
+
from climategan.trainer import Trainer
|
35 |
+
from climategan.utils import find_images
|
36 |
+
|
37 |
+
dict_metrics = {
|
38 |
+
"names": {
|
39 |
+
"tpr": "TPR, Recall, Sensitivity",
|
40 |
+
"tnr": "TNR, Specificity, Selectivity",
|
41 |
+
"fpr": "FPR",
|
42 |
+
"fpt": "False positives relative to image size",
|
43 |
+
"fnr": "FNR, Miss rate",
|
44 |
+
"fnt": "False negatives relative to image size",
|
45 |
+
"mpr": "May positive rate (MPR)",
|
46 |
+
"mnr": "May negative rate (MNR)",
|
47 |
+
"accuracy": "Accuracy (ignoring may)",
|
48 |
+
"error": "Error (ignoring may)",
|
49 |
+
"f05": "F0.05 score",
|
50 |
+
"precision": "Precision",
|
51 |
+
"edge_coherence": "Edge coherence",
|
52 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
53 |
+
},
|
54 |
+
"threshold": {
|
55 |
+
"tpr": 0.95,
|
56 |
+
"tnr": 0.95,
|
57 |
+
"fpr": 0.05,
|
58 |
+
"fpt": 0.01,
|
59 |
+
"fnr": 0.05,
|
60 |
+
"fnt": 0.01,
|
61 |
+
"accuracy": 0.95,
|
62 |
+
"error": 0.05,
|
63 |
+
"f05": 0.95,
|
64 |
+
"precision": 0.95,
|
65 |
+
"edge_coherence": 0.02,
|
66 |
+
"accuracy_must_may": 0.5,
|
67 |
+
},
|
68 |
+
"key_metrics": ["f05", "error", "edge_coherence", "mnr"],
|
69 |
+
}
|
70 |
+
|
71 |
+
print("Ok.")
|
72 |
+
|
73 |
+
|
74 |
+
def parsed_args():
|
75 |
+
"""Parse and returns command-line args
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
argparse.Namespace: the parsed arguments
|
79 |
+
"""
|
80 |
+
parser = ArgumentParser()
|
81 |
+
parser.add_argument(
|
82 |
+
"--model",
|
83 |
+
type=str,
|
84 |
+
help="Path to a pre-trained model",
|
85 |
+
)
|
86 |
+
parser.add_argument(
|
87 |
+
"--images_dir",
|
88 |
+
default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/imgs",
|
89 |
+
type=str,
|
90 |
+
help="Directory containing the original test images",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--labels_dir",
|
94 |
+
default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/labels",
|
95 |
+
type=str,
|
96 |
+
help="Directory containing the labeled images",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--image_size",
|
100 |
+
default=640,
|
101 |
+
type=int,
|
102 |
+
help="The height and weight of the pre-processed images",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--max_files",
|
106 |
+
default=-1,
|
107 |
+
type=int,
|
108 |
+
help="Limit loaded samples",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--bin_value", default=0.5, type=float, help="Mask binarization threshold"
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"-y",
|
115 |
+
"--yaml",
|
116 |
+
default=None,
|
117 |
+
type=str,
|
118 |
+
help="load a yaml file to parametrize the evaluation",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"-p",
|
125 |
+
"--plot",
|
126 |
+
action="store_true",
|
127 |
+
default=False,
|
128 |
+
help="Plot masker images & their metrics overlays",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--no_paint",
|
132 |
+
action="store_true",
|
133 |
+
default=False,
|
134 |
+
help="Do not log painted images",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--write_metrics",
|
138 |
+
action="store_true",
|
139 |
+
default=False,
|
140 |
+
help="If True, write CSV file and maps images in model's path directory",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--load_metrics",
|
144 |
+
action="store_true",
|
145 |
+
default=False,
|
146 |
+
help="If True, load predictions and metrics instead of re-computing",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--prepare_torch",
|
150 |
+
action="store_true",
|
151 |
+
default=False,
|
152 |
+
help="If True, pre-process images as torch tensors",
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--output_csv",
|
156 |
+
default=None,
|
157 |
+
type=str,
|
158 |
+
help="Filename of the output CSV with the metrics of all models",
|
159 |
+
)
|
160 |
+
|
161 |
+
return parser.parse_args()
|
162 |
+
|
163 |
+
|
164 |
+
def uint8(array):
|
165 |
+
return array.astype(np.uint8)
|
166 |
+
|
167 |
+
|
168 |
+
def crop_and_resize(image_path, label_path):
|
169 |
+
"""
|
170 |
+
Resizes an image so that it keeps the aspect ratio and the smallest dimensions
|
171 |
+
is 640, then crops this resized image in its center so that the output is 640x640
|
172 |
+
without aspect ratio distortion
|
173 |
+
|
174 |
+
Args:
|
175 |
+
image_path (Path or str): Path to an image
|
176 |
+
label_path (Path or str): Path to the image's associated label
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
tuple((np.ndarray, np.ndarray)): (new image, new label)
|
180 |
+
"""
|
181 |
+
|
182 |
+
img = imread(image_path)
|
183 |
+
lab = imread(label_path)
|
184 |
+
|
185 |
+
# if img.shape[-1] == 4:
|
186 |
+
# img = uint8(rgba2rgb(img) * 255)
|
187 |
+
|
188 |
+
# TODO: remove (debug)
|
189 |
+
if img.shape[:2] != lab.shape[:2]:
|
190 |
+
print(
|
191 |
+
"\nWARNING: shape mismatch: im -> ({}) {}, lab -> ({}) {}".format(
|
192 |
+
img.shape[:2], image_path.name, lab.shape[:2], label_path.name
|
193 |
+
)
|
194 |
+
)
|
195 |
+
# breakpoint()
|
196 |
+
|
197 |
+
# resize keeping aspect ratio: smallest dim is 640
|
198 |
+
i_h, i_w = img.shape[:2]
|
199 |
+
if i_h < i_w:
|
200 |
+
i_size = (640, int(640 * i_w / i_h))
|
201 |
+
else:
|
202 |
+
i_size = (int(640 * i_h / i_w), 640)
|
203 |
+
|
204 |
+
l_h, l_w = img.shape[:2]
|
205 |
+
if l_h < l_w:
|
206 |
+
l_size = (640, int(640 * l_w / l_h))
|
207 |
+
else:
|
208 |
+
l_size = (int(640 * l_h / l_w), 640)
|
209 |
+
|
210 |
+
r_img = resize(img, i_size, preserve_range=True, anti_aliasing=True)
|
211 |
+
r_img = uint8(r_img)
|
212 |
+
|
213 |
+
r_lab = resize(lab, l_size, preserve_range=True, anti_aliasing=False, order=0)
|
214 |
+
r_lab = uint8(r_lab)
|
215 |
+
|
216 |
+
# crop in the center
|
217 |
+
H, W = r_img.shape[:2]
|
218 |
+
|
219 |
+
top = (H - 640) // 2
|
220 |
+
left = (W - 640) // 2
|
221 |
+
|
222 |
+
rc_img = r_img[top : top + 640, left : left + 640, :]
|
223 |
+
rc_lab = (
|
224 |
+
r_lab[top : top + 640, left : left + 640, :]
|
225 |
+
if r_lab.ndim == 3
|
226 |
+
else r_lab[top : top + 640, left : left + 640]
|
227 |
+
)
|
228 |
+
|
229 |
+
return rc_img, rc_lab
|
230 |
+
|
231 |
+
|
232 |
+
def plot_images(
|
233 |
+
output_filename,
|
234 |
+
img,
|
235 |
+
label,
|
236 |
+
pred,
|
237 |
+
metrics_dict,
|
238 |
+
maps_dict,
|
239 |
+
edge_coherence=-1,
|
240 |
+
pred_edge=None,
|
241 |
+
label_edge=None,
|
242 |
+
dpi=300,
|
243 |
+
alpha=0.5,
|
244 |
+
vmin=0.0,
|
245 |
+
vmax=1.0,
|
246 |
+
fontsize="xx-small",
|
247 |
+
cmap={
|
248 |
+
"fp": "Reds",
|
249 |
+
"fn": "Reds",
|
250 |
+
"may_neg": "Oranges",
|
251 |
+
"may_pos": "Purples",
|
252 |
+
"pred": "Greens",
|
253 |
+
},
|
254 |
+
):
|
255 |
+
f, axes = plt.subplots(1, 5, dpi=dpi)
|
256 |
+
|
257 |
+
# FPR (predicted mask on cannot flood)
|
258 |
+
axes[0].imshow(img)
|
259 |
+
fp_map_plt = axes[0].imshow( # noqa: F841
|
260 |
+
maps_dict["fp"], vmin=vmin, vmax=vmax, cmap=cmap["fp"], alpha=alpha
|
261 |
+
)
|
262 |
+
axes[0].axis("off")
|
263 |
+
axes[0].set_title("FPR: {:.4f}".format(metrics_dict["fpr"]), fontsize=fontsize)
|
264 |
+
|
265 |
+
# FNR (missed mask on must flood)
|
266 |
+
axes[1].imshow(img)
|
267 |
+
fn_map_plt = axes[1].imshow( # noqa: F841
|
268 |
+
maps_dict["fn"], vmin=vmin, vmax=vmax, cmap=cmap["fn"], alpha=alpha
|
269 |
+
)
|
270 |
+
axes[1].axis("off")
|
271 |
+
axes[1].set_title("FNR: {:.4f}".format(metrics_dict["fnr"]), fontsize=fontsize)
|
272 |
+
|
273 |
+
# May flood
|
274 |
+
axes[2].imshow(img)
|
275 |
+
if edge_coherence != -1:
|
276 |
+
title = "MNR: {:.2f} | MPR: {:.2f}\nEdge coh.: {:.4f}".format(
|
277 |
+
metrics_dict["mnr"], metrics_dict["mpr"], edge_coherence
|
278 |
+
)
|
279 |
+
# alpha_here = alpha / 4.
|
280 |
+
# pred_edge_plt = axes[2].imshow(
|
281 |
+
# 1.0 - pred_edge, cmap="gray", alpha=alpha_here
|
282 |
+
# )
|
283 |
+
# label_edge_plt = axes[2].imshow(
|
284 |
+
# 1.0 - label_edge, cmap="gray", alpha=alpha_here
|
285 |
+
# )
|
286 |
+
else:
|
287 |
+
title = "MNR: {:.2f} | MPR: {:.2f}".format(mnr, mpr) # noqa: F821
|
288 |
+
# alpha_here = alpha / 2.
|
289 |
+
may_neg_map_plt = axes[2].imshow( # noqa: F841
|
290 |
+
maps_dict["may_neg"], vmin=vmin, vmax=vmax, cmap=cmap["may_neg"], alpha=alpha
|
291 |
+
)
|
292 |
+
may_pos_map_plt = axes[2].imshow( # noqa: F841
|
293 |
+
maps_dict["may_pos"], vmin=vmin, vmax=vmax, cmap=cmap["may_pos"], alpha=alpha
|
294 |
+
)
|
295 |
+
axes[2].set_title(title, fontsize=fontsize)
|
296 |
+
axes[2].axis("off")
|
297 |
+
|
298 |
+
# Prediction
|
299 |
+
axes[3].imshow(img)
|
300 |
+
pred_mask = axes[3].imshow( # noqa: F841
|
301 |
+
pred, vmin=vmin, vmax=vmax, cmap=cmap["pred"], alpha=alpha
|
302 |
+
)
|
303 |
+
axes[3].set_title("Predicted mask", fontsize=fontsize)
|
304 |
+
axes[3].axis("off")
|
305 |
+
|
306 |
+
# Labels
|
307 |
+
axes[4].imshow(img)
|
308 |
+
label_mask = axes[4].imshow(label, alpha=alpha) # noqa: F841
|
309 |
+
axes[4].set_title("Labels", fontsize=fontsize)
|
310 |
+
axes[4].axis("off")
|
311 |
+
|
312 |
+
f.savefig(
|
313 |
+
output_filename,
|
314 |
+
dpi=f.dpi,
|
315 |
+
bbox_inches="tight",
|
316 |
+
facecolor="white",
|
317 |
+
transparent=False,
|
318 |
+
)
|
319 |
+
plt.close(f)
|
320 |
+
|
321 |
+
|
322 |
+
def load_ground(ground_output_path, ref_image_path):
|
323 |
+
gop = Path(ground_output_path)
|
324 |
+
rip = Path(ref_image_path)
|
325 |
+
|
326 |
+
ground_paths = list((gop / "eval-metrics" / "pred").glob(f"{rip.stem}.jpg")) + list(
|
327 |
+
(gop / "eval-metrics" / "pred").glob(f"{rip.stem}.png")
|
328 |
+
)
|
329 |
+
if len(ground_paths) == 0:
|
330 |
+
raise ValueError(
|
331 |
+
f"Could not find a ground match in {str(gop)} for image {str(rip)}"
|
332 |
+
)
|
333 |
+
elif len(ground_paths) > 1:
|
334 |
+
raise ValueError(
|
335 |
+
f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
|
336 |
+
+ f" {list(map(str, ground_paths))}"
|
337 |
+
)
|
338 |
+
ground_path = ground_paths[0]
|
339 |
+
_, ground = crop_and_resize(rip, ground_path)
|
340 |
+
if ground.ndim == 3:
|
341 |
+
ground = ground[:, :, 0]
|
342 |
+
ground = (ground > 0).astype(np.float32)
|
343 |
+
return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).cuda()
|
344 |
+
|
345 |
+
|
346 |
+
def get_inferences(
|
347 |
+
image_arrays, model_path, image_paths, paint=False, bin_value=0.5, verbose=0
|
348 |
+
):
|
349 |
+
"""
|
350 |
+
Obtains the mask predictions of a model for a set of images
|
351 |
+
|
352 |
+
Parameters
|
353 |
+
----------
|
354 |
+
image_arrays : array-like
|
355 |
+
A list of (1, CH, H, W) images
|
356 |
+
|
357 |
+
image_paths: list(Path)
|
358 |
+
A list of paths for images, in the same order as image_arrays
|
359 |
+
|
360 |
+
model_path : str
|
361 |
+
The path to a pre-trained model
|
362 |
+
|
363 |
+
Returns
|
364 |
+
-------
|
365 |
+
masks : list
|
366 |
+
A list of (H, W) predicted masks
|
367 |
+
"""
|
368 |
+
device = torch.device("cuda:0")
|
369 |
+
torch.set_grad_enabled(False)
|
370 |
+
to_tensor = ToTensor()
|
371 |
+
|
372 |
+
is_ground = "ground" in Path(model_path).name
|
373 |
+
is_instagan = "instagan" in Path(model_path).name
|
374 |
+
|
375 |
+
if is_ground or is_instagan:
|
376 |
+
# we just care about he painter here
|
377 |
+
ground_path = model_path
|
378 |
+
model_path = (
|
379 |
+
"/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--38858350"
|
380 |
+
)
|
381 |
+
|
382 |
+
xs = [to_tensor(array).unsqueeze(0) for array in image_arrays]
|
383 |
+
xs = [x.to(torch.float32).to(device) for x in xs]
|
384 |
+
xs = [(x - 0.5) * 2 for x in xs]
|
385 |
+
trainer = Trainer.resume_from_path(
|
386 |
+
model_path, inference=True, new_exp=None, device=device
|
387 |
+
)
|
388 |
+
masks = []
|
389 |
+
painted = []
|
390 |
+
for idx, x in enumerate(xs):
|
391 |
+
if verbose > 0:
|
392 |
+
print(idx, "/", len(xs), end="\r")
|
393 |
+
|
394 |
+
if not is_ground and not is_instagan:
|
395 |
+
m = trainer.G.mask(x=x)
|
396 |
+
else:
|
397 |
+
m = load_ground(ground_path, image_paths[idx])
|
398 |
+
|
399 |
+
masks.append(m.squeeze().cpu())
|
400 |
+
if paint:
|
401 |
+
p = trainer.G.paint(m > bin_value, x)
|
402 |
+
painted.append(p.squeeze().cpu())
|
403 |
+
return masks, painted
|
404 |
+
|
405 |
+
|
406 |
+
if __name__ == "__main__":
|
407 |
+
# -----------------------------
|
408 |
+
# ----- Parse arguments -----
|
409 |
+
# -----------------------------
|
410 |
+
args = parsed_args()
|
411 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
412 |
+
|
413 |
+
# Determine output dir
|
414 |
+
try:
|
415 |
+
tmp_dir = Path(os.environ["SLURM_TMPDIR"])
|
416 |
+
except Exception as e:
|
417 |
+
print(e)
|
418 |
+
tmp_dir = Path(input("Enter tmp output directory: ")).resolve()
|
419 |
+
|
420 |
+
plot_dir = tmp_dir / "plots"
|
421 |
+
plot_dir.mkdir(parents=True, exist_ok=True)
|
422 |
+
|
423 |
+
# Build paths to data
|
424 |
+
imgs_paths = sorted(
|
425 |
+
find_images(args.images_dir, recursive=False), key=lambda x: x.name
|
426 |
+
)
|
427 |
+
labels_paths = sorted(
|
428 |
+
find_images(args.labels_dir, recursive=False),
|
429 |
+
key=lambda x: x.name.replace("_labeled.", "."),
|
430 |
+
)
|
431 |
+
if args.max_files > 0:
|
432 |
+
imgs_paths = imgs_paths[: args.max_files]
|
433 |
+
labels_paths = labels_paths[: args.max_files]
|
434 |
+
|
435 |
+
print(f"Loading {len(imgs_paths)} images and labels...")
|
436 |
+
|
437 |
+
# Pre-process images: resize + crop
|
438 |
+
# TODO: ? make cropping more flexible, not only central
|
439 |
+
if not args.prepare_torch:
|
440 |
+
ims_labs = [crop_and_resize(i, l) for i, l in zip(imgs_paths, labels_paths)]
|
441 |
+
imgs = [d[0] for d in ims_labs]
|
442 |
+
labels = [d[1] for d in ims_labs]
|
443 |
+
else:
|
444 |
+
prepare = PrepareTest()
|
445 |
+
imgs = prepare(imgs_paths, normalize=False, rescale=False)
|
446 |
+
labels = prepare(labels_paths, normalize=False, rescale=False)
|
447 |
+
|
448 |
+
imgs = [i.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for i in imgs]
|
449 |
+
labels = [
|
450 |
+
lab.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for lab in labels
|
451 |
+
]
|
452 |
+
imgs = [rgba2rgb(img) if img.shape[-1] == 4 else img for img in imgs]
|
453 |
+
print(" Done.")
|
454 |
+
|
455 |
+
# Encode labels
|
456 |
+
print("Encode labels...", end="", flush=True)
|
457 |
+
# HW label
|
458 |
+
labels = [np.squeeze(encode_mask_label(label, "flood")) for label in labels]
|
459 |
+
print("Done.")
|
460 |
+
|
461 |
+
if args.yaml:
|
462 |
+
y_path = Path(args.yaml)
|
463 |
+
assert y_path.exists()
|
464 |
+
assert y_path.suffix in {".yaml", ".yml"}
|
465 |
+
with y_path.open("r") as f:
|
466 |
+
data = yaml.safe_load(f)
|
467 |
+
assert "models" in data
|
468 |
+
|
469 |
+
evaluations = [m for m in data["models"]]
|
470 |
+
else:
|
471 |
+
evaluations = [args.model]
|
472 |
+
|
473 |
+
for e, eval_path in enumerate(evaluations):
|
474 |
+
print("\n>>>>> Evaluation", e, ":", eval_path)
|
475 |
+
print("=" * 50)
|
476 |
+
print("=" * 50)
|
477 |
+
|
478 |
+
model_metrics_path = Path(eval_path) / "eval-metrics"
|
479 |
+
model_metrics_path.mkdir(exist_ok=True)
|
480 |
+
if args.load_metrics:
|
481 |
+
f_csv = model_metrics_path / "eval_masker.csv"
|
482 |
+
pred_out = model_metrics_path / "pred"
|
483 |
+
if f_csv.exists() and pred_out.exists():
|
484 |
+
print("Skipping model because pre-computed metrics exist")
|
485 |
+
continue
|
486 |
+
|
487 |
+
# Initialize New Comet Experiment
|
488 |
+
exp = Experiment(
|
489 |
+
project_name="climategan-masker-metrics", display_summary_level=0
|
490 |
+
)
|
491 |
+
|
492 |
+
# Obtain mask predictions
|
493 |
+
# TODO: remove (debug)
|
494 |
+
print("Obtain mask predictions", end="", flush=True)
|
495 |
+
|
496 |
+
preds, painted = get_inferences(
|
497 |
+
imgs,
|
498 |
+
eval_path,
|
499 |
+
imgs_paths,
|
500 |
+
paint=not args.no_paint,
|
501 |
+
bin_value=args.bin_value,
|
502 |
+
verbose=1,
|
503 |
+
)
|
504 |
+
preds = [pred.numpy() for pred in preds]
|
505 |
+
print(" Done.")
|
506 |
+
|
507 |
+
if args.bin_value > 0:
|
508 |
+
preds = [pred > args.bin_value for pred in preds]
|
509 |
+
|
510 |
+
# Compute metrics
|
511 |
+
df = pd.DataFrame(
|
512 |
+
columns=[
|
513 |
+
"tpr",
|
514 |
+
"tpt",
|
515 |
+
"tnr",
|
516 |
+
"tnt",
|
517 |
+
"fpr",
|
518 |
+
"fpt",
|
519 |
+
"fnr",
|
520 |
+
"fnt",
|
521 |
+
"mnr",
|
522 |
+
"mpr",
|
523 |
+
"accuracy",
|
524 |
+
"error",
|
525 |
+
"precision",
|
526 |
+
"f05",
|
527 |
+
"accuracy_must_may",
|
528 |
+
"edge_coherence",
|
529 |
+
"filename",
|
530 |
+
]
|
531 |
+
)
|
532 |
+
|
533 |
+
print("Compute metrics and plot images")
|
534 |
+
for idx, (img, label, pred) in enumerate(zip(*(imgs, labels, preds))):
|
535 |
+
print(idx, "/", len(imgs), end="\r")
|
536 |
+
|
537 |
+
# Basic classification metrics
|
538 |
+
metrics_dict, maps_dict = masker_classification_metrics(
|
539 |
+
pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
|
540 |
+
)
|
541 |
+
|
542 |
+
# Edges coherence
|
543 |
+
edge_coherence, pred_edge, label_edge = edges_coherence_std_min(pred, label)
|
544 |
+
|
545 |
+
series_dict = {
|
546 |
+
"tpr": metrics_dict["tpr"],
|
547 |
+
"tpt": metrics_dict["tpt"],
|
548 |
+
"tnr": metrics_dict["tnr"],
|
549 |
+
"tnt": metrics_dict["tnt"],
|
550 |
+
"fpr": metrics_dict["fpr"],
|
551 |
+
"fpt": metrics_dict["fpt"],
|
552 |
+
"fnr": metrics_dict["fnr"],
|
553 |
+
"fnt": metrics_dict["fnt"],
|
554 |
+
"mnr": metrics_dict["mnr"],
|
555 |
+
"mpr": metrics_dict["mpr"],
|
556 |
+
"accuracy": metrics_dict["accuracy"],
|
557 |
+
"error": metrics_dict["error"],
|
558 |
+
"precision": metrics_dict["precision"],
|
559 |
+
"f05": metrics_dict["f05"],
|
560 |
+
"accuracy_must_may": metrics_dict["accuracy_must_may"],
|
561 |
+
"edge_coherence": edge_coherence,
|
562 |
+
"filename": str(imgs_paths[idx].name),
|
563 |
+
}
|
564 |
+
df.loc[idx] = pd.Series(series_dict)
|
565 |
+
|
566 |
+
for k, v in series_dict.items():
|
567 |
+
if k == "filename":
|
568 |
+
continue
|
569 |
+
exp.log_metric(f"img_{k}", v, step=idx)
|
570 |
+
|
571 |
+
# Confusion matrix
|
572 |
+
confmat, _ = get_confusion_matrix(
|
573 |
+
metrics_dict["tpr"],
|
574 |
+
metrics_dict["tnr"],
|
575 |
+
metrics_dict["fpr"],
|
576 |
+
metrics_dict["fnr"],
|
577 |
+
metrics_dict["mnr"],
|
578 |
+
metrics_dict["mpr"],
|
579 |
+
)
|
580 |
+
confmat = np.around(confmat, decimals=3)
|
581 |
+
exp.log_confusion_matrix(
|
582 |
+
file_name=imgs_paths[idx].name + ".json",
|
583 |
+
title=imgs_paths[idx].name,
|
584 |
+
matrix=confmat,
|
585 |
+
labels=["Cannot", "Must", "May"],
|
586 |
+
row_label="Predicted",
|
587 |
+
column_label="Ground truth",
|
588 |
+
)
|
589 |
+
|
590 |
+
if args.plot:
|
591 |
+
# Plot prediction images
|
592 |
+
fig_filename = plot_dir / imgs_paths[idx].name
|
593 |
+
plot_images(
|
594 |
+
fig_filename,
|
595 |
+
img,
|
596 |
+
label,
|
597 |
+
pred,
|
598 |
+
metrics_dict,
|
599 |
+
maps_dict,
|
600 |
+
edge_coherence,
|
601 |
+
pred_edge,
|
602 |
+
label_edge,
|
603 |
+
)
|
604 |
+
exp.log_image(fig_filename)
|
605 |
+
if not args.no_paint:
|
606 |
+
masked = img * (1 - pred[..., None])
|
607 |
+
flooded = img_as_ubyte(
|
608 |
+
(painted[idx].permute(1, 2, 0).cpu().numpy() + 1) / 2
|
609 |
+
)
|
610 |
+
combined = np.concatenate([img, masked, flooded], 1)
|
611 |
+
exp.log_image(combined, imgs_paths[idx].name)
|
612 |
+
|
613 |
+
if args.write_metrics:
|
614 |
+
pred_out = model_metrics_path / "pred"
|
615 |
+
pred_out.mkdir(exist_ok=True)
|
616 |
+
imsave(
|
617 |
+
pred_out / f"{imgs_paths[idx].stem}_pred.png",
|
618 |
+
pred.astype(np.uint8),
|
619 |
+
)
|
620 |
+
for k, v in maps_dict.items():
|
621 |
+
metric_out = model_metrics_path / k
|
622 |
+
metric_out.mkdir(exist_ok=True)
|
623 |
+
imsave(
|
624 |
+
metric_out / f"{imgs_paths[idx].stem}_{k}.png",
|
625 |
+
v.astype(np.uint8),
|
626 |
+
)
|
627 |
+
|
628 |
+
# --------------------------------
|
629 |
+
# ----- END OF IMAGES LOOP -----
|
630 |
+
# --------------------------------
|
631 |
+
|
632 |
+
if args.write_metrics:
|
633 |
+
print(f"Writing metrics in {str(model_metrics_path)}")
|
634 |
+
f_csv = model_metrics_path / "eval_masker.csv"
|
635 |
+
df.to_csv(f_csv, index_label="idx")
|
636 |
+
|
637 |
+
print(" Done.")
|
638 |
+
# Summary statistics
|
639 |
+
means = df.mean(axis=0)
|
640 |
+
confmat_mean, confmat_std = get_confusion_matrix(
|
641 |
+
df.tpr, df.tnr, df.fpr, df.fnr, df.mpr, df.mnr
|
642 |
+
)
|
643 |
+
confmat_mean = np.around(confmat_mean, decimals=3)
|
644 |
+
confmat_std = np.around(confmat_std, decimals=3)
|
645 |
+
|
646 |
+
# Log to comet
|
647 |
+
exp.log_confusion_matrix(
|
648 |
+
file_name="confusion_matrix_mean.json",
|
649 |
+
title="confusion_matrix_mean.json",
|
650 |
+
matrix=confmat_mean,
|
651 |
+
labels=["Cannot", "Must", "May"],
|
652 |
+
row_label="Predicted",
|
653 |
+
column_label="Ground truth",
|
654 |
+
)
|
655 |
+
exp.log_confusion_matrix(
|
656 |
+
file_name="confusion_matrix_std.json",
|
657 |
+
title="confusion_matrix_std.json",
|
658 |
+
matrix=confmat_std,
|
659 |
+
labels=["Cannot", "Must", "May"],
|
660 |
+
row_label="Predicted",
|
661 |
+
column_label="Ground truth",
|
662 |
+
)
|
663 |
+
exp.log_metrics(dict(means))
|
664 |
+
exp.log_table("metrics.csv", df)
|
665 |
+
exp.log_html(df.to_html(col_space="80px"))
|
666 |
+
exp.log_parameters(vars(args))
|
667 |
+
exp.log_parameter("eval_path", str(eval_path))
|
668 |
+
exp.add_tag("eval_masker")
|
669 |
+
if args.tags:
|
670 |
+
exp.add_tags(args.tags)
|
671 |
+
exp.log_parameter("model_id", Path(eval_path).name)
|
672 |
+
|
673 |
+
# Close comet
|
674 |
+
exp.end()
|
675 |
+
|
676 |
+
# --------------------------------
|
677 |
+
# ----- END OF MODElS LOOP -----
|
678 |
+
# --------------------------------
|
679 |
+
|
680 |
+
# Compare models
|
681 |
+
if (args.load_metrics or args.write_metrics) and len(evaluations) > 1:
|
682 |
+
print(
|
683 |
+
"Plots for comparing the input models will be created and logged to comet"
|
684 |
+
)
|
685 |
+
|
686 |
+
# Initialize New Comet Experiment
|
687 |
+
exp = Experiment(
|
688 |
+
project_name="climategan-masker-metrics", display_summary_level=0
|
689 |
+
)
|
690 |
+
if args.tags:
|
691 |
+
exp.add_tags(args.tags)
|
692 |
+
|
693 |
+
# Build DataFrame with all models
|
694 |
+
print("Building pandas DataFrame...")
|
695 |
+
models_df = {}
|
696 |
+
for (m, model_path) in enumerate(evaluations):
|
697 |
+
model_path = Path(model_path)
|
698 |
+
with open(model_path / "opts.yaml", "r") as f:
|
699 |
+
opt = yaml.safe_load(f)
|
700 |
+
model_feats = ", ".join(
|
701 |
+
[
|
702 |
+
t
|
703 |
+
for t in sorted(opt["comet"]["tags"])
|
704 |
+
if "branch" not in t and "ablation" not in t and "trash" not in t
|
705 |
+
]
|
706 |
+
)
|
707 |
+
model_id = f"{model_path.parent.name[-2:]}/{model_path.name}"
|
708 |
+
df_m = pd.read_csv(
|
709 |
+
model_path / "eval-metrics" / "eval_masker.csv", index_col=False
|
710 |
+
)
|
711 |
+
df_m["model"] = [model_id] * len(df_m)
|
712 |
+
df_m["model_idx"] = [m] * len(df_m)
|
713 |
+
df_m["model_feats"] = [model_feats] * len(df_m)
|
714 |
+
models_df.update({model_id: df_m})
|
715 |
+
df = pd.concat(list(models_df.values()), ignore_index=True)
|
716 |
+
df["model_img_idx"] = df.model.astype(str) + "-" + df.idx.astype(str)
|
717 |
+
df.rename(columns={"idx": "img_idx"}, inplace=True)
|
718 |
+
dict_models_labels = {
|
719 |
+
k: f"{v['model_idx'][0]}: {v['model_feats'][0]}"
|
720 |
+
for k, v in models_df.items()
|
721 |
+
}
|
722 |
+
print("Done")
|
723 |
+
|
724 |
+
if args.output_csv:
|
725 |
+
print(f"Writing DataFrame to {args.output_csv}")
|
726 |
+
df.to_csv(args.output_csv, index_label="model_img_idx")
|
727 |
+
|
728 |
+
# Determine images with low metrics in any model
|
729 |
+
print("Constructing filter based on metrics thresholds...")
|
730 |
+
idx_not_good_in_any = []
|
731 |
+
for idx in df.img_idx.unique():
|
732 |
+
df_th = df.loc[
|
733 |
+
(
|
734 |
+
# TODO: rethink thresholds
|
735 |
+
(df.tpr <= dict_metrics["threshold"]["tpr"])
|
736 |
+
| (df.fpr >= dict_metrics["threshold"]["fpr"])
|
737 |
+
| (df.edge_coherence >= dict_metrics["threshold"]["edge_coherence"])
|
738 |
+
)
|
739 |
+
& ((df.img_idx == idx) & (df.model.isin(df.model.unique())))
|
740 |
+
]
|
741 |
+
if len(df_th) > 0:
|
742 |
+
idx_not_good_in_any.append(idx)
|
743 |
+
filters = {"all": df.img_idx.unique(), "not_good_in_any": idx_not_good_in_any}
|
744 |
+
print("Done")
|
745 |
+
|
746 |
+
# Boxplots of metrics
|
747 |
+
print("Plotting boxplots of metrics...")
|
748 |
+
for k, f in filters.items():
|
749 |
+
print(f"\tDistribution of [{k}] images...")
|
750 |
+
for metric in dict_metrics["names"].keys():
|
751 |
+
fig_filename = plot_dir / f"boxplot_{metric}_{k}.png"
|
752 |
+
if metric in ["mnr", "mpr", "accuracy_must_may"]:
|
753 |
+
boxplot_metric(
|
754 |
+
fig_filename,
|
755 |
+
df.loc[df.img_idx.isin(f)],
|
756 |
+
metric=metric,
|
757 |
+
dict_metrics=dict_metrics["names"],
|
758 |
+
do_stripplot=True,
|
759 |
+
dict_models=dict_models_labels,
|
760 |
+
order=list(df.model.unique()),
|
761 |
+
)
|
762 |
+
else:
|
763 |
+
boxplot_metric(
|
764 |
+
fig_filename,
|
765 |
+
df.loc[df.img_idx.isin(f)],
|
766 |
+
metric=metric,
|
767 |
+
dict_metrics=dict_metrics["names"],
|
768 |
+
dict_models=dict_models_labels,
|
769 |
+
fliersize=1.0,
|
770 |
+
order=list(df.model.unique()),
|
771 |
+
)
|
772 |
+
exp.log_image(fig_filename)
|
773 |
+
print("Done")
|
774 |
+
|
775 |
+
# Cluster Maps
|
776 |
+
print("Plotting clustermaps...")
|
777 |
+
for k, f in filters.items():
|
778 |
+
print(f"\tDistribution of [{k}] images...")
|
779 |
+
for metric in dict_metrics["names"].keys():
|
780 |
+
fig_filename = plot_dir / f"clustermap_{metric}_{k}.png"
|
781 |
+
df_mf = df.loc[df.img_idx.isin(f)].pivot("img_idx", "model", metric)
|
782 |
+
clustermap_metric(
|
783 |
+
output_filename=fig_filename,
|
784 |
+
df=df_mf,
|
785 |
+
metric=metric,
|
786 |
+
dict_metrics=dict_metrics["names"],
|
787 |
+
method="average",
|
788 |
+
cluster_metric="euclidean",
|
789 |
+
dict_models=dict_models_labels,
|
790 |
+
row_cluster=False,
|
791 |
+
)
|
792 |
+
exp.log_image(fig_filename)
|
793 |
+
print("Done")
|
794 |
+
|
795 |
+
# Close comet
|
796 |
+
exp.end()
|
figures/ablation_comparison.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script evaluates the contribution of a technique from the ablation study for
|
3 |
+
improving the masker evaluation metrics. The differences in the metrics are computed
|
4 |
+
for all images of paired models, that is those which only differ in the inclusion or
|
5 |
+
not of the given technique. Then, statistical inference is performed through the
|
6 |
+
percentile bootstrap to obtain robust estimates of the differences in the metrics and
|
7 |
+
confidence intervals. The script plots the distribution of the bootrstraped estimates.
|
8 |
+
"""
|
9 |
+
print("Imports...", end="")
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
import yaml
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
import os
|
16 |
+
from pathlib import Path
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
import matplotlib.patches as mpatches
|
19 |
+
import matplotlib.transforms as transforms
|
20 |
+
|
21 |
+
|
22 |
+
# -----------------------
|
23 |
+
# ----- Constants -----
|
24 |
+
# -----------------------
|
25 |
+
|
26 |
+
dict_models = {
|
27 |
+
"md": 11,
|
28 |
+
"dada_ms, msd, pseudo": 9,
|
29 |
+
"msd, pseudo": 4,
|
30 |
+
"dada, msd_spade, pseudo": 7,
|
31 |
+
"msd": 13,
|
32 |
+
"dada_m, msd": 17,
|
33 |
+
"dada, msd_spade": 16,
|
34 |
+
"msd_spade, pseudo": 5,
|
35 |
+
"dada_ms, msd": 18,
|
36 |
+
"dada, msd, pseudo": 6,
|
37 |
+
"ms": 12,
|
38 |
+
"dada, msd": 15,
|
39 |
+
"dada_m, msd, pseudo": 8,
|
40 |
+
"msd_spade": 14,
|
41 |
+
"m": 10,
|
42 |
+
"md, pseudo": 2,
|
43 |
+
"ms, pseudo": 3,
|
44 |
+
"m, pseudo": 1,
|
45 |
+
"ground": "G",
|
46 |
+
"instagan": "I",
|
47 |
+
}
|
48 |
+
|
49 |
+
dict_metrics = {
|
50 |
+
"names": {
|
51 |
+
"tpr": "TPR, Recall, Sensitivity",
|
52 |
+
"tnr": "TNR, Specificity, Selectivity",
|
53 |
+
"fpr": "FPR",
|
54 |
+
"fpt": "False positives relative to image size",
|
55 |
+
"fnr": "FNR, Miss rate",
|
56 |
+
"fnt": "False negatives relative to image size",
|
57 |
+
"mpr": "May positive rate (MPR)",
|
58 |
+
"mnr": "May negative rate (MNR)",
|
59 |
+
"accuracy": "Accuracy (ignoring may)",
|
60 |
+
"error": "Error",
|
61 |
+
"f05": "F05 score",
|
62 |
+
"precision": "Precision",
|
63 |
+
"edge_coherence": "Edge coherence",
|
64 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
65 |
+
},
|
66 |
+
"key_metrics": ["f05", "error", "edge_coherence"],
|
67 |
+
}
|
68 |
+
dict_techniques = {
|
69 |
+
"depth": "depth",
|
70 |
+
"segmentation": "seg",
|
71 |
+
"seg": "seg",
|
72 |
+
"dada_s": "dada_seg",
|
73 |
+
"dada_seg": "dada_seg",
|
74 |
+
"dada_segmentation": "dada_seg",
|
75 |
+
"dada_m": "dada_masker",
|
76 |
+
"dada_masker": "dada_masker",
|
77 |
+
"spade": "spade",
|
78 |
+
"pseudo": "pseudo",
|
79 |
+
"pseudo-labels": "pseudo",
|
80 |
+
"pseudo_labels": "pseudo",
|
81 |
+
}
|
82 |
+
|
83 |
+
# Markers
|
84 |
+
dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
|
85 |
+
|
86 |
+
# Model features
|
87 |
+
model_feats = [
|
88 |
+
"masker",
|
89 |
+
"seg",
|
90 |
+
"depth",
|
91 |
+
"dada_seg",
|
92 |
+
"dada_masker",
|
93 |
+
"spade",
|
94 |
+
"pseudo",
|
95 |
+
"ground",
|
96 |
+
"instagan",
|
97 |
+
]
|
98 |
+
|
99 |
+
# Colors
|
100 |
+
palette_colorblind = sns.color_palette("colorblind")
|
101 |
+
color_climategan = palette_colorblind[0]
|
102 |
+
color_munit = palette_colorblind[1]
|
103 |
+
color_cyclegan = palette_colorblind[6]
|
104 |
+
color_instagan = palette_colorblind[8]
|
105 |
+
color_maskinstagan = palette_colorblind[2]
|
106 |
+
color_paintedground = palette_colorblind[3]
|
107 |
+
|
108 |
+
color_cat1 = palette_colorblind[0]
|
109 |
+
color_cat2 = palette_colorblind[1]
|
110 |
+
palette_lightest = [
|
111 |
+
sns.light_palette(color_cat1, n_colors=20)[3],
|
112 |
+
sns.light_palette(color_cat2, n_colors=20)[3],
|
113 |
+
]
|
114 |
+
palette_light = [
|
115 |
+
sns.light_palette(color_cat1, n_colors=3)[1],
|
116 |
+
sns.light_palette(color_cat2, n_colors=3)[1],
|
117 |
+
]
|
118 |
+
palette_medium = [color_cat1, color_cat2]
|
119 |
+
palette_dark = [
|
120 |
+
sns.dark_palette(color_cat1, n_colors=3)[1],
|
121 |
+
sns.dark_palette(color_cat2, n_colors=3)[1],
|
122 |
+
]
|
123 |
+
palette_cat1 = [
|
124 |
+
palette_lightest[0],
|
125 |
+
palette_light[0],
|
126 |
+
palette_medium[0],
|
127 |
+
palette_dark[0],
|
128 |
+
]
|
129 |
+
palette_cat2 = [
|
130 |
+
palette_lightest[1],
|
131 |
+
palette_light[1],
|
132 |
+
palette_medium[1],
|
133 |
+
palette_dark[1],
|
134 |
+
]
|
135 |
+
color_cat1_light = palette_light[0]
|
136 |
+
color_cat2_light = palette_light[1]
|
137 |
+
|
138 |
+
|
139 |
+
def parsed_args():
|
140 |
+
"""
|
141 |
+
Parse and returns command-line args
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
argparse.Namespace: the parsed arguments
|
145 |
+
"""
|
146 |
+
parser = ArgumentParser()
|
147 |
+
parser.add_argument(
|
148 |
+
"--input_csv",
|
149 |
+
default="ablations_metrics_20210311.csv",
|
150 |
+
type=str,
|
151 |
+
help="CSV containing the results of the ablation study",
|
152 |
+
)
|
153 |
+
parser.add_argument(
|
154 |
+
"--output_dir",
|
155 |
+
default=None,
|
156 |
+
type=str,
|
157 |
+
help="Output directory",
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--models",
|
161 |
+
default="all",
|
162 |
+
type=str,
|
163 |
+
help="Models to display: all, pseudo, no_dada_masker, no_baseline",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--dpi",
|
167 |
+
default=200,
|
168 |
+
type=int,
|
169 |
+
help="DPI for the output images",
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--n_bs",
|
173 |
+
default=1e6,
|
174 |
+
type=int,
|
175 |
+
help="Number of bootrstrap samples",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--alpha",
|
179 |
+
default=0.99,
|
180 |
+
type=float,
|
181 |
+
help="Confidence level",
|
182 |
+
)
|
183 |
+
parser.add_argument(
|
184 |
+
"--bs_seed",
|
185 |
+
default=17,
|
186 |
+
type=int,
|
187 |
+
help="Bootstrap random seed, for reproducibility",
|
188 |
+
)
|
189 |
+
|
190 |
+
return parser.parse_args()
|
191 |
+
|
192 |
+
|
193 |
+
def plot_median_metrics(
|
194 |
+
df, do_stripplot=True, dpi=200, bs_seed=37, n_bs=1000, **snskwargs
|
195 |
+
):
|
196 |
+
def plot_metric(
|
197 |
+
ax, df, metric, do_stripplot=True, dpi=200, bs_seed=37, marker="o", **snskwargs
|
198 |
+
):
|
199 |
+
|
200 |
+
y_labels = [dict_models[f] for f in df.model_feats.unique()]
|
201 |
+
|
202 |
+
# Labels
|
203 |
+
y_labels_int = np.sort([el for el in y_labels if isinstance(el, int)]).tolist()
|
204 |
+
y_order_int = [
|
205 |
+
k for vs in y_labels_int for k, vu in dict_models.items() if vs == vu
|
206 |
+
]
|
207 |
+
y_labels_int = [str(el) for el in y_labels_int]
|
208 |
+
|
209 |
+
y_labels_str = sorted([el for el in y_labels if not isinstance(el, int)])
|
210 |
+
y_order_str = [
|
211 |
+
k for vs in y_labels_str for k, vu in dict_models.items() if vs == vu
|
212 |
+
]
|
213 |
+
y_labels = y_labels_int + y_labels_str
|
214 |
+
y_order = y_order_int + y_order_str
|
215 |
+
|
216 |
+
# Palette
|
217 |
+
palette = len(y_labels_int) * [color_climategan]
|
218 |
+
for y in y_labels_str:
|
219 |
+
if y == "G":
|
220 |
+
palette = palette + [color_paintedground]
|
221 |
+
if y == "I":
|
222 |
+
palette = palette + [color_maskinstagan]
|
223 |
+
|
224 |
+
# Error
|
225 |
+
sns.pointplot(
|
226 |
+
ax=ax,
|
227 |
+
data=df,
|
228 |
+
x=metric,
|
229 |
+
y="model_feats",
|
230 |
+
order=y_order,
|
231 |
+
markers=marker,
|
232 |
+
estimator=np.median,
|
233 |
+
ci=99,
|
234 |
+
seed=bs_seed,
|
235 |
+
n_boot=n_bs,
|
236 |
+
join=False,
|
237 |
+
scale=0.6,
|
238 |
+
errwidth=1.5,
|
239 |
+
capsize=0.1,
|
240 |
+
palette=palette,
|
241 |
+
)
|
242 |
+
xlim = ax.get_xlim()
|
243 |
+
|
244 |
+
if do_stripplot:
|
245 |
+
sns.stripplot(
|
246 |
+
ax=ax,
|
247 |
+
data=df,
|
248 |
+
x=metric,
|
249 |
+
y="model_feats",
|
250 |
+
size=1.5,
|
251 |
+
palette=palette,
|
252 |
+
alpha=0.2,
|
253 |
+
)
|
254 |
+
ax.set_xlim(xlim)
|
255 |
+
|
256 |
+
# Set X-label
|
257 |
+
ax.set_xlabel(dict_metrics["names"][metric], rotation=0, fontsize="medium")
|
258 |
+
|
259 |
+
# Set Y-label
|
260 |
+
ax.set_ylabel(None)
|
261 |
+
|
262 |
+
ax.set_yticklabels(y_labels, fontsize="medium")
|
263 |
+
|
264 |
+
# Change spines
|
265 |
+
sns.despine(ax=ax, left=True, bottom=True)
|
266 |
+
|
267 |
+
# Draw gray area on final model
|
268 |
+
xlim = ax.get_xlim()
|
269 |
+
ylim = ax.get_ylim()
|
270 |
+
trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
|
271 |
+
rect = mpatches.Rectangle(
|
272 |
+
xy=(0.0, 5.5),
|
273 |
+
width=1,
|
274 |
+
height=1,
|
275 |
+
transform=trans,
|
276 |
+
linewidth=0.0,
|
277 |
+
edgecolor="none",
|
278 |
+
facecolor="gray",
|
279 |
+
alpha=0.05,
|
280 |
+
)
|
281 |
+
ax.add_patch(rect)
|
282 |
+
|
283 |
+
# Set up plot
|
284 |
+
sns.set(style="whitegrid")
|
285 |
+
plt.rcParams.update({"font.family": "serif"})
|
286 |
+
plt.rcParams.update(
|
287 |
+
{
|
288 |
+
"font.serif": [
|
289 |
+
"Computer Modern Roman",
|
290 |
+
"Times New Roman",
|
291 |
+
"Utopia",
|
292 |
+
"New Century Schoolbook",
|
293 |
+
"Century Schoolbook L",
|
294 |
+
"ITC Bookman",
|
295 |
+
"Bookman",
|
296 |
+
"Times",
|
297 |
+
"Palatino",
|
298 |
+
"Charter",
|
299 |
+
"serif" "Bitstream Vera Serif",
|
300 |
+
"DejaVu Serif",
|
301 |
+
]
|
302 |
+
}
|
303 |
+
)
|
304 |
+
|
305 |
+
fig_h = 0.4 * len(df.model_feats.unique())
|
306 |
+
fig, axes = plt.subplots(
|
307 |
+
nrows=1, ncols=3, sharey=True, dpi=dpi, figsize=(18, fig_h)
|
308 |
+
)
|
309 |
+
|
310 |
+
# Error
|
311 |
+
plot_metric(
|
312 |
+
axes[0],
|
313 |
+
df,
|
314 |
+
"error",
|
315 |
+
do_stripplot=do_stripplot,
|
316 |
+
dpi=dpi,
|
317 |
+
bs_seed=bs_seed,
|
318 |
+
marker=dict_markers["error"],
|
319 |
+
)
|
320 |
+
axes[0].set_ylabel("Models")
|
321 |
+
|
322 |
+
# F05
|
323 |
+
plot_metric(
|
324 |
+
axes[1],
|
325 |
+
df,
|
326 |
+
"f05",
|
327 |
+
do_stripplot=do_stripplot,
|
328 |
+
dpi=dpi,
|
329 |
+
bs_seed=bs_seed,
|
330 |
+
marker=dict_markers["f05"],
|
331 |
+
)
|
332 |
+
|
333 |
+
# Edge coherence
|
334 |
+
plot_metric(
|
335 |
+
axes[2],
|
336 |
+
df,
|
337 |
+
"edge_coherence",
|
338 |
+
do_stripplot=do_stripplot,
|
339 |
+
dpi=dpi,
|
340 |
+
bs_seed=bs_seed,
|
341 |
+
marker=dict_markers["edge_coherence"],
|
342 |
+
)
|
343 |
+
xticks = axes[2].get_xticks()
|
344 |
+
xticklabels = ["{:.3f}".format(x) for x in xticks]
|
345 |
+
axes[2].set(xticks=xticks, xticklabels=xticklabels)
|
346 |
+
|
347 |
+
plt.subplots_adjust(wspace=0.12)
|
348 |
+
|
349 |
+
return fig
|
350 |
+
|
351 |
+
|
352 |
+
if __name__ == "__main__":
|
353 |
+
# -----------------------------
|
354 |
+
# ----- Parse arguments -----
|
355 |
+
# -----------------------------
|
356 |
+
args = parsed_args()
|
357 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
358 |
+
|
359 |
+
# Determine output dir
|
360 |
+
if args.output_dir is None:
|
361 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
362 |
+
else:
|
363 |
+
output_dir = Path(args.output_dir)
|
364 |
+
if not output_dir.exists():
|
365 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
366 |
+
|
367 |
+
# Store args
|
368 |
+
output_yml = output_dir / "ablation_comparison_{}.yml".format(args.models)
|
369 |
+
with open(output_yml, "w") as f:
|
370 |
+
yaml.dump(vars(args), f)
|
371 |
+
|
372 |
+
# Read CSV
|
373 |
+
df = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
374 |
+
|
375 |
+
# Determine models
|
376 |
+
if "all" in args.models.lower():
|
377 |
+
pass
|
378 |
+
else:
|
379 |
+
if "no_baseline" in args.models.lower():
|
380 |
+
df = df.loc[(df.ground == False) & (df.instagan == False)]
|
381 |
+
if "pseudo" in args.models.lower():
|
382 |
+
df = df.loc[
|
383 |
+
(df.pseudo == True) | (df.ground == True) | (df.instagan == True)
|
384 |
+
]
|
385 |
+
if "no_dada_mask" in args.models.lower():
|
386 |
+
df = df.loc[
|
387 |
+
(df.dada_masker == False) | (df.ground == True) | (df.instagan == True)
|
388 |
+
]
|
389 |
+
|
390 |
+
fig = plot_median_metrics(df, do_stripplot=True, dpi=args.dpi, bs_seed=args.bs_seed)
|
391 |
+
|
392 |
+
# Save figure
|
393 |
+
output_fig = output_dir / "ablation_comparison_{}.png".format(args.models)
|
394 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
figures/bootstrap_ablation.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script evaluates the contribution of a technique from the ablation study for
|
3 |
+
improving the masker evaluation metrics. The differences in the metrics are computed
|
4 |
+
for all images of paired models, that is those which only differ in the inclusion or
|
5 |
+
not of the given technique. Then, statistical inference is performed through the
|
6 |
+
percentile bootstrap to obtain robust estimates of the differences in the metrics and
|
7 |
+
confidence intervals. The script plots the distribution of the bootrstraped estimates.
|
8 |
+
"""
|
9 |
+
print("Imports...", end="")
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
import yaml
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
import seaborn as sns
|
16 |
+
from scipy.stats import trim_mean
|
17 |
+
from tqdm import tqdm
|
18 |
+
from pathlib import Path
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
import matplotlib.patches as mpatches
|
21 |
+
|
22 |
+
|
23 |
+
# -----------------------
|
24 |
+
# ----- Constants -----
|
25 |
+
# -----------------------
|
26 |
+
|
27 |
+
dict_metrics = {
|
28 |
+
"names": {
|
29 |
+
"tpr": "TPR, Recall, Sensitivity",
|
30 |
+
"tnr": "TNR, Specificity, Selectivity",
|
31 |
+
"fpr": "FPR",
|
32 |
+
"fpt": "False positives relative to image size",
|
33 |
+
"fnr": "FNR, Miss rate",
|
34 |
+
"fnt": "False negatives relative to image size",
|
35 |
+
"mpr": "May positive rate (MPR)",
|
36 |
+
"mnr": "May negative rate (MNR)",
|
37 |
+
"accuracy": "Accuracy (ignoring may)",
|
38 |
+
"error": "Error",
|
39 |
+
"f05": "F05 score",
|
40 |
+
"precision": "Precision",
|
41 |
+
"edge_coherence": "Edge coherence",
|
42 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
43 |
+
},
|
44 |
+
"key_metrics": ["f05", "error", "edge_coherence"],
|
45 |
+
}
|
46 |
+
dict_techniques = {
|
47 |
+
"depth": "depth",
|
48 |
+
"segmentation": "seg",
|
49 |
+
"seg": "seg",
|
50 |
+
"dada_s": "dada_seg",
|
51 |
+
"dada_seg": "dada_seg",
|
52 |
+
"dada_segmentation": "dada_seg",
|
53 |
+
"dada_m": "dada_masker",
|
54 |
+
"dada_masker": "dada_masker",
|
55 |
+
"spade": "spade",
|
56 |
+
"pseudo": "pseudo",
|
57 |
+
"pseudo-labels": "pseudo",
|
58 |
+
"pseudo_labels": "pseudo",
|
59 |
+
}
|
60 |
+
|
61 |
+
# Model features
|
62 |
+
model_feats = [
|
63 |
+
"masker",
|
64 |
+
"seg",
|
65 |
+
"depth",
|
66 |
+
"dada_seg",
|
67 |
+
"dada_masker",
|
68 |
+
"spade",
|
69 |
+
"pseudo",
|
70 |
+
"ground",
|
71 |
+
"instagan",
|
72 |
+
]
|
73 |
+
|
74 |
+
# Colors
|
75 |
+
palette_colorblind = sns.color_palette("colorblind")
|
76 |
+
color_cat1 = palette_colorblind[0]
|
77 |
+
color_cat2 = palette_colorblind[1]
|
78 |
+
palette_lightest = [
|
79 |
+
sns.light_palette(color_cat1, n_colors=20)[3],
|
80 |
+
sns.light_palette(color_cat2, n_colors=20)[3],
|
81 |
+
]
|
82 |
+
palette_light = [
|
83 |
+
sns.light_palette(color_cat1, n_colors=3)[1],
|
84 |
+
sns.light_palette(color_cat2, n_colors=3)[1],
|
85 |
+
]
|
86 |
+
palette_medium = [color_cat1, color_cat2]
|
87 |
+
palette_dark = [
|
88 |
+
sns.dark_palette(color_cat1, n_colors=3)[1],
|
89 |
+
sns.dark_palette(color_cat2, n_colors=3)[1],
|
90 |
+
]
|
91 |
+
palette_cat1 = [
|
92 |
+
palette_lightest[0],
|
93 |
+
palette_light[0],
|
94 |
+
palette_medium[0],
|
95 |
+
palette_dark[0],
|
96 |
+
]
|
97 |
+
palette_cat2 = [
|
98 |
+
palette_lightest[1],
|
99 |
+
palette_light[1],
|
100 |
+
palette_medium[1],
|
101 |
+
palette_dark[1],
|
102 |
+
]
|
103 |
+
color_cat1_light = palette_light[0]
|
104 |
+
color_cat2_light = palette_light[1]
|
105 |
+
|
106 |
+
|
107 |
+
def parsed_args():
|
108 |
+
"""
|
109 |
+
Parse and returns command-line args
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
argparse.Namespace: the parsed arguments
|
113 |
+
"""
|
114 |
+
parser = ArgumentParser()
|
115 |
+
parser.add_argument(
|
116 |
+
"--input_csv",
|
117 |
+
default="ablations_metrics_20210311.csv",
|
118 |
+
type=str,
|
119 |
+
help="CSV containing the results of the ablation study",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--output_dir",
|
123 |
+
default=None,
|
124 |
+
type=str,
|
125 |
+
help="Output directory",
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--technique",
|
129 |
+
default=None,
|
130 |
+
type=str,
|
131 |
+
help="Keyword specifying the technique. One of: pseudo, depth, segmentation, dada_seg, dada_masker, spade",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--dpi",
|
135 |
+
default=200,
|
136 |
+
type=int,
|
137 |
+
help="DPI for the output images",
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--n_bs",
|
141 |
+
default=1e6,
|
142 |
+
type=int,
|
143 |
+
help="Number of bootrstrap samples",
|
144 |
+
)
|
145 |
+
parser.add_argument(
|
146 |
+
"--alpha",
|
147 |
+
default=0.99,
|
148 |
+
type=float,
|
149 |
+
help="Confidence level",
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--bs_seed",
|
153 |
+
default=17,
|
154 |
+
type=int,
|
155 |
+
help="Bootstrap random seed, for reproducibility",
|
156 |
+
)
|
157 |
+
|
158 |
+
return parser.parse_args()
|
159 |
+
|
160 |
+
|
161 |
+
def add_ci_mean(
|
162 |
+
ax, sample_measure, bs_mean, bs_std, ci, color, alpha, fontsize, invert=False
|
163 |
+
):
|
164 |
+
|
165 |
+
# Fill area between CI
|
166 |
+
dist = ax.lines[0]
|
167 |
+
dist_y = dist.get_ydata()
|
168 |
+
dist_x = dist.get_xdata()
|
169 |
+
linewidth = dist.get_linewidth()
|
170 |
+
|
171 |
+
x_idx_low = np.argmin(np.abs(dist_x - ci[0]))
|
172 |
+
x_idx_high = np.argmin(np.abs(dist_x - ci[1]))
|
173 |
+
x_ci = dist_x[x_idx_low:x_idx_high]
|
174 |
+
y_ci = dist_y[x_idx_low:x_idx_high]
|
175 |
+
|
176 |
+
ax.fill_between(x_ci, 0, y_ci, facecolor=color, alpha=alpha)
|
177 |
+
|
178 |
+
# Add vertical lines of CI
|
179 |
+
ax.vlines(
|
180 |
+
x=ci[0],
|
181 |
+
ymin=0.0,
|
182 |
+
ymax=y_ci[0],
|
183 |
+
color=color,
|
184 |
+
linewidth=linewidth,
|
185 |
+
label="ci_low",
|
186 |
+
)
|
187 |
+
ax.vlines(
|
188 |
+
x=ci[1],
|
189 |
+
ymin=0.0,
|
190 |
+
ymax=y_ci[-1],
|
191 |
+
color=color,
|
192 |
+
linewidth=linewidth,
|
193 |
+
label="ci_high",
|
194 |
+
)
|
195 |
+
|
196 |
+
# Add annotations
|
197 |
+
bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
|
198 |
+
|
199 |
+
if invert:
|
200 |
+
ha_l = "right"
|
201 |
+
ha_u = "left"
|
202 |
+
else:
|
203 |
+
ha_l = "left"
|
204 |
+
ha_u = "right"
|
205 |
+
ax.text(
|
206 |
+
ci[0],
|
207 |
+
0.0,
|
208 |
+
s="L = {:.4f}".format(ci[0]),
|
209 |
+
ha=ha_l,
|
210 |
+
va="bottom",
|
211 |
+
fontsize=fontsize,
|
212 |
+
bbox=bbox_props,
|
213 |
+
)
|
214 |
+
ax.text(
|
215 |
+
ci[1],
|
216 |
+
0.0,
|
217 |
+
s="U = {:.4f}".format(ci[1]),
|
218 |
+
ha=ha_u,
|
219 |
+
va="bottom",
|
220 |
+
fontsize=fontsize,
|
221 |
+
bbox=bbox_props,
|
222 |
+
)
|
223 |
+
|
224 |
+
# Add vertical line of bootstrap mean
|
225 |
+
x_idx_mean = np.argmin(np.abs(dist_x - bs_mean))
|
226 |
+
ax.vlines(
|
227 |
+
x=bs_mean, ymin=0.0, ymax=dist_y[x_idx_mean], color="k", linewidth=linewidth
|
228 |
+
)
|
229 |
+
|
230 |
+
# Add annotation of bootstrap mean
|
231 |
+
bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
|
232 |
+
|
233 |
+
ax.text(
|
234 |
+
bs_mean,
|
235 |
+
0.6 * dist_y[x_idx_mean],
|
236 |
+
s="Bootstrap mean = {:.4f}".format(bs_mean),
|
237 |
+
ha="center",
|
238 |
+
va="center",
|
239 |
+
fontsize=fontsize,
|
240 |
+
bbox=bbox_props,
|
241 |
+
)
|
242 |
+
|
243 |
+
# Add vertical line of sample_measure
|
244 |
+
x_idx_smeas = np.argmin(np.abs(dist_x - sample_measure))
|
245 |
+
ax.vlines(
|
246 |
+
x=sample_measure,
|
247 |
+
ymin=0.0,
|
248 |
+
ymax=dist_y[x_idx_smeas],
|
249 |
+
color="k",
|
250 |
+
linewidth=linewidth,
|
251 |
+
linestyles="dotted",
|
252 |
+
)
|
253 |
+
|
254 |
+
# Add SD
|
255 |
+
bbox_props = dict(boxstyle="darrow, pad=0.4", fc="w", ec="k", lw=2)
|
256 |
+
|
257 |
+
ax.text(
|
258 |
+
bs_mean,
|
259 |
+
0.4 * dist_y[x_idx_mean],
|
260 |
+
s="SD = {:.4f} = SE".format(bs_std),
|
261 |
+
ha="center",
|
262 |
+
va="center",
|
263 |
+
fontsize=fontsize,
|
264 |
+
bbox=bbox_props,
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
def add_null_pval(ax, null, color, alpha, fontsize):
|
269 |
+
|
270 |
+
# Fill area between CI
|
271 |
+
dist = ax.lines[0]
|
272 |
+
dist_y = dist.get_ydata()
|
273 |
+
dist_x = dist.get_xdata()
|
274 |
+
linewidth = dist.get_linewidth()
|
275 |
+
|
276 |
+
x_idx_null = np.argmin(np.abs(dist_x - null))
|
277 |
+
if x_idx_null >= (len(dist_x) / 2.0):
|
278 |
+
x_pval = dist_x[x_idx_null:]
|
279 |
+
y_pval = dist_y[x_idx_null:]
|
280 |
+
else:
|
281 |
+
x_pval = dist_x[:x_idx_null]
|
282 |
+
y_pval = dist_y[:x_idx_null]
|
283 |
+
|
284 |
+
ax.fill_between(x_pval, 0, y_pval, facecolor=color, alpha=alpha)
|
285 |
+
|
286 |
+
# Add vertical lines of null
|
287 |
+
dist = ax.lines[0]
|
288 |
+
linewidth = dist.get_linewidth()
|
289 |
+
y_max = ax.get_ylim()[1]
|
290 |
+
ax.vlines(
|
291 |
+
x=null,
|
292 |
+
ymin=0.0,
|
293 |
+
ymax=y_max,
|
294 |
+
color="k",
|
295 |
+
linewidth=linewidth,
|
296 |
+
linestyles="dotted",
|
297 |
+
)
|
298 |
+
|
299 |
+
# Add annotations
|
300 |
+
bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
|
301 |
+
|
302 |
+
ax.text(
|
303 |
+
null,
|
304 |
+
0.75 * y_max,
|
305 |
+
s="Null hypothesis = {:.1f}".format(null),
|
306 |
+
ha="center",
|
307 |
+
va="center",
|
308 |
+
fontsize=fontsize,
|
309 |
+
bbox=bbox_props,
|
310 |
+
)
|
311 |
+
|
312 |
+
|
313 |
+
def plot_bootstrap_distr(
|
314 |
+
sample_measure, bs_samples, alpha, color_ci, color_pval=None, null=None
|
315 |
+
):
|
316 |
+
|
317 |
+
# Compute results from bootstrap
|
318 |
+
q_low = (1.0 - alpha) / 2.0
|
319 |
+
q_high = 1.0 - q_low
|
320 |
+
ci = np.quantile(bs_samples, [q_low, q_high])
|
321 |
+
bs_mean = np.mean(bs_samples)
|
322 |
+
bs_std = np.std(bs_samples)
|
323 |
+
|
324 |
+
if null is not None and color_pval is not None:
|
325 |
+
pval_flag = True
|
326 |
+
pval = np.min([[np.mean(bs_samples > null), np.mean(bs_samples < null)]]) * 2
|
327 |
+
else:
|
328 |
+
pval_flag = False
|
329 |
+
|
330 |
+
# Set up plot
|
331 |
+
sns.set(style="whitegrid")
|
332 |
+
fontsize = 24
|
333 |
+
font = {"family": "DejaVu Sans", "weight": "normal", "size": fontsize}
|
334 |
+
plt.rc("font", **font)
|
335 |
+
alpha_plot = 0.5
|
336 |
+
|
337 |
+
# Initialize the matplotlib figure
|
338 |
+
fig, ax = plt.subplots(figsize=(30, 12), dpi=args.dpi)
|
339 |
+
|
340 |
+
# Plot distribution of bootstrap means
|
341 |
+
sns.kdeplot(bs_samples, color="b", linewidth=5, gridsize=1000, ax=ax)
|
342 |
+
|
343 |
+
y_lim = ax.get_ylim()
|
344 |
+
|
345 |
+
# Change spines
|
346 |
+
sns.despine(left=True, bottom=True)
|
347 |
+
|
348 |
+
# Annotations
|
349 |
+
add_ci_mean(
|
350 |
+
ax,
|
351 |
+
sample_measure,
|
352 |
+
bs_mean,
|
353 |
+
bs_std,
|
354 |
+
ci,
|
355 |
+
color=color_ci,
|
356 |
+
alpha=alpha_plot,
|
357 |
+
fontsize=fontsize,
|
358 |
+
)
|
359 |
+
|
360 |
+
if pval_flag:
|
361 |
+
add_null_pval(ax, null, color=color_pval, alpha=alpha_plot, fontsize=fontsize)
|
362 |
+
|
363 |
+
# Legend
|
364 |
+
ci_patch = mpatches.Patch(
|
365 |
+
facecolor=color_ci,
|
366 |
+
edgecolor=None,
|
367 |
+
alpha=alpha_plot,
|
368 |
+
label="{:d} % confidence interval".format(int(100 * alpha)),
|
369 |
+
)
|
370 |
+
|
371 |
+
if pval_flag:
|
372 |
+
if pval == 0.0:
|
373 |
+
pval_patch = mpatches.Patch(
|
374 |
+
facecolor=color_pval,
|
375 |
+
edgecolor=None,
|
376 |
+
alpha=alpha_plot,
|
377 |
+
label="P value / 2 = {:.1f}".format(pval / 2.0),
|
378 |
+
)
|
379 |
+
elif np.around(pval / 2.0, decimals=4) > 0.0000:
|
380 |
+
pval_patch = mpatches.Patch(
|
381 |
+
facecolor=color_pval,
|
382 |
+
edgecolor=None,
|
383 |
+
alpha=alpha_plot,
|
384 |
+
label="P value / 2 = {:.4f}".format(pval / 2.0),
|
385 |
+
)
|
386 |
+
else:
|
387 |
+
pval_patch = mpatches.Patch(
|
388 |
+
facecolor=color_pval,
|
389 |
+
edgecolor=None,
|
390 |
+
alpha=alpha_plot,
|
391 |
+
label="P value / 2 < $10^{}$".format(np.ceil(np.log10(pval / 2.0))),
|
392 |
+
)
|
393 |
+
|
394 |
+
leg = ax.legend(
|
395 |
+
handles=[ci_patch, pval_patch],
|
396 |
+
ncol=1,
|
397 |
+
loc="upper right",
|
398 |
+
frameon=True,
|
399 |
+
framealpha=1.0,
|
400 |
+
title="",
|
401 |
+
fontsize=fontsize,
|
402 |
+
columnspacing=1.0,
|
403 |
+
labelspacing=0.2,
|
404 |
+
markerfirst=True,
|
405 |
+
)
|
406 |
+
else:
|
407 |
+
leg = ax.legend(
|
408 |
+
handles=[ci_patch],
|
409 |
+
ncol=1,
|
410 |
+
loc="upper right",
|
411 |
+
frameon=True,
|
412 |
+
framealpha=1.0,
|
413 |
+
title="",
|
414 |
+
fontsize=fontsize,
|
415 |
+
columnspacing=1.0,
|
416 |
+
labelspacing=0.2,
|
417 |
+
markerfirst=True,
|
418 |
+
)
|
419 |
+
|
420 |
+
plt.setp(leg.get_title(), fontsize=fontsize, horizontalalignment="left")
|
421 |
+
|
422 |
+
# Set X-label
|
423 |
+
ax.set_xlabel("Bootstrap estimates", rotation=0, fontsize=fontsize, labelpad=10.0)
|
424 |
+
|
425 |
+
# Set Y-label
|
426 |
+
ax.set_ylabel("Density", rotation=90, fontsize=fontsize, labelpad=10.0)
|
427 |
+
|
428 |
+
# Ticks
|
429 |
+
plt.setp(ax.get_xticklabels(), fontsize=0.8 * fontsize, verticalalignment="top")
|
430 |
+
plt.setp(ax.get_yticklabels(), fontsize=0.8 * fontsize)
|
431 |
+
|
432 |
+
ax.set_ylim(y_lim)
|
433 |
+
|
434 |
+
return fig, bs_mean, bs_std, ci, pval
|
435 |
+
|
436 |
+
|
437 |
+
if __name__ == "__main__":
|
438 |
+
# -----------------------------
|
439 |
+
# ----- Parse arguments -----
|
440 |
+
# -----------------------------
|
441 |
+
args = parsed_args()
|
442 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
443 |
+
|
444 |
+
# Determine output dir
|
445 |
+
if args.output_dir is None:
|
446 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
447 |
+
else:
|
448 |
+
output_dir = Path(args.output_dir)
|
449 |
+
if not output_dir.exists():
|
450 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
451 |
+
|
452 |
+
# Store args
|
453 |
+
output_yml = output_dir / "{}_bootstrap.yml".format(args.technique)
|
454 |
+
with open(output_yml, "w") as f:
|
455 |
+
yaml.dump(vars(args), f)
|
456 |
+
|
457 |
+
# Determine technique
|
458 |
+
if args.technique.lower() not in dict_techniques:
|
459 |
+
raise ValueError("{} is not a valid technique".format(args.technique))
|
460 |
+
else:
|
461 |
+
technique = dict_techniques[args.technique.lower()]
|
462 |
+
|
463 |
+
# Read CSV
|
464 |
+
df = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
465 |
+
|
466 |
+
# Find relevant model pairs
|
467 |
+
model_pairs = []
|
468 |
+
for mi in df.loc[df[technique]].model_feats.unique():
|
469 |
+
for mj in df.model_feats.unique():
|
470 |
+
if mj == mi:
|
471 |
+
continue
|
472 |
+
|
473 |
+
if df.loc[df.model_feats == mj, technique].unique()[0]:
|
474 |
+
continue
|
475 |
+
|
476 |
+
is_pair = True
|
477 |
+
for f in model_feats:
|
478 |
+
if f == technique:
|
479 |
+
continue
|
480 |
+
elif (
|
481 |
+
df.loc[df.model_feats == mj, f].unique()[0]
|
482 |
+
!= df.loc[df.model_feats == mi, f].unique()[0]
|
483 |
+
):
|
484 |
+
is_pair = False
|
485 |
+
break
|
486 |
+
else:
|
487 |
+
pass
|
488 |
+
if is_pair:
|
489 |
+
model_pairs.append((mi, mj))
|
490 |
+
break
|
491 |
+
|
492 |
+
print("\nModel pairs identified:\n")
|
493 |
+
for pair in model_pairs:
|
494 |
+
print("{} & {}".format(pair[0], pair[1]))
|
495 |
+
|
496 |
+
df["base"] = ["N/A"] * len(df)
|
497 |
+
for spp in model_pairs:
|
498 |
+
df.loc[df.model_feats.isin(spp), "depth_base"] = spp[1]
|
499 |
+
|
500 |
+
# Build bootstrap data
|
501 |
+
data = {m: [] for m in dict_metrics["key_metrics"]}
|
502 |
+
for m_with, m_without in model_pairs:
|
503 |
+
df_with = df.loc[df.model_feats == m_with]
|
504 |
+
df_without = df.loc[df.model_feats == m_without]
|
505 |
+
for metric in data.keys():
|
506 |
+
diff = (
|
507 |
+
df_with.sort_values(by="img_idx")[metric].values
|
508 |
+
- df_without.sort_values(by="img_idx")[metric].values
|
509 |
+
)
|
510 |
+
data[metric].extend(diff.tolist())
|
511 |
+
|
512 |
+
# Run bootstrap
|
513 |
+
measures = ["mean", "median", "20_trimmed_mean"]
|
514 |
+
bs_data = {meas: {m: np.zeros(args.n_bs) for m in data.keys()} for meas in measures}
|
515 |
+
|
516 |
+
np.random.seed(args.bs_seed)
|
517 |
+
for m, data_m in data.items():
|
518 |
+
for idx, s in enumerate(tqdm(range(args.n_bs))):
|
519 |
+
# Sample with replacement
|
520 |
+
bs_sample = np.random.choice(data_m, size=len(data_m), replace=True)
|
521 |
+
|
522 |
+
# Store mean
|
523 |
+
bs_data["mean"][m][idx] = np.mean(bs_sample)
|
524 |
+
|
525 |
+
# Store median
|
526 |
+
bs_data["median"][m][idx] = np.median(bs_sample)
|
527 |
+
|
528 |
+
# Store 20 % trimmed mean
|
529 |
+
bs_data["20_trimmed_mean"][m][idx] = trim_mean(bs_sample, 0.2)
|
530 |
+
|
531 |
+
for metric in dict_metrics["key_metrics"]:
|
532 |
+
sample_measure = trim_mean(data[metric], 0.2)
|
533 |
+
fig, bs_mean, bs_std, ci, pval = plot_bootstrap_distr(
|
534 |
+
sample_measure,
|
535 |
+
bs_data["20_trimmed_mean"][metric],
|
536 |
+
alpha=args.alpha,
|
537 |
+
color_ci=color_cat1_light,
|
538 |
+
color_pval=color_cat2_light,
|
539 |
+
null=0.0,
|
540 |
+
)
|
541 |
+
|
542 |
+
# Save figure
|
543 |
+
output_fig = output_dir / "{}_bootstrap_{}_{}.png".format(
|
544 |
+
args.technique, metric, "20_trimmed_mean"
|
545 |
+
)
|
546 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
547 |
+
|
548 |
+
# Store results
|
549 |
+
output_results = output_dir / "{}_bootstrap_{}_{}.yml".format(
|
550 |
+
args.technique, metric, "20_trimmed_mean"
|
551 |
+
)
|
552 |
+
results_dict = {
|
553 |
+
"measure": "20_trimmed_mean",
|
554 |
+
"sample_measure": float(sample_measure),
|
555 |
+
"bs_mean": float(bs_mean),
|
556 |
+
"bs_std": float(bs_std),
|
557 |
+
"ci_left": float(ci[0]),
|
558 |
+
"ci_right": float(ci[1]),
|
559 |
+
"pval": float(pval),
|
560 |
+
}
|
561 |
+
with open(output_results, "w") as f:
|
562 |
+
yaml.dump(results_dict, f)
|
figures/bootstrap_ablation_summary.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script computes the median difference and confidence intervals of all techniques from the ablation study for
|
3 |
+
improving the masker evaluation metrics. The differences in the metrics are computed
|
4 |
+
for all images of paired models, that is those which only differ in the inclusion or
|
5 |
+
not of the given technique. Then, statistical inference is performed through the
|
6 |
+
percentile bootstrap to obtain robust estimates of the differences in the metrics and
|
7 |
+
confidence intervals. The script plots the summary for all techniques.
|
8 |
+
"""
|
9 |
+
print("Imports...", end="")
|
10 |
+
from argparse import ArgumentParser
|
11 |
+
import yaml
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
from scipy.special import comb
|
16 |
+
from scipy.stats import trim_mean
|
17 |
+
from tqdm import tqdm
|
18 |
+
from collections import OrderedDict
|
19 |
+
from pathlib import Path
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
import matplotlib.patches as mpatches
|
22 |
+
import matplotlib.transforms as transforms
|
23 |
+
|
24 |
+
|
25 |
+
# -----------------------
|
26 |
+
# ----- Constants -----
|
27 |
+
# -----------------------
|
28 |
+
|
29 |
+
dict_metrics = {
|
30 |
+
"names": {
|
31 |
+
"tpr": "TPR, Recall, Sensitivity",
|
32 |
+
"tnr": "TNR, Specificity, Selectivity",
|
33 |
+
"fpr": "FPR",
|
34 |
+
"fpt": "False positives relative to image size",
|
35 |
+
"fnr": "FNR, Miss rate",
|
36 |
+
"fnt": "False negatives relative to image size",
|
37 |
+
"mpr": "May positive rate (MPR)",
|
38 |
+
"mnr": "May negative rate (MNR)",
|
39 |
+
"accuracy": "Accuracy (ignoring may)",
|
40 |
+
"error": "Error",
|
41 |
+
"f05": "F05 score",
|
42 |
+
"precision": "Precision",
|
43 |
+
"edge_coherence": "Edge coherence",
|
44 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
45 |
+
},
|
46 |
+
"key_metrics": ["error", "f05", "edge_coherence"],
|
47 |
+
}
|
48 |
+
|
49 |
+
dict_techniques = OrderedDict(
|
50 |
+
[
|
51 |
+
("pseudo", "Pseudo labels"),
|
52 |
+
("depth", "Depth (D)"),
|
53 |
+
("seg", "Seg. (S)"),
|
54 |
+
("spade", "SPADE"),
|
55 |
+
("dada_seg", "DADA (S)"),
|
56 |
+
("dada_masker", "DADA (M)"),
|
57 |
+
]
|
58 |
+
)
|
59 |
+
|
60 |
+
# Model features
|
61 |
+
model_feats = [
|
62 |
+
"masker",
|
63 |
+
"seg",
|
64 |
+
"depth",
|
65 |
+
"dada_seg",
|
66 |
+
"dada_masker",
|
67 |
+
"spade",
|
68 |
+
"pseudo",
|
69 |
+
"ground",
|
70 |
+
"instagan",
|
71 |
+
]
|
72 |
+
|
73 |
+
# Colors
|
74 |
+
crest = sns.color_palette("crest", as_cmap=False, n_colors=7)
|
75 |
+
palette_metrics = [crest[0], crest[3], crest[6]]
|
76 |
+
sns.palplot(palette_metrics)
|
77 |
+
|
78 |
+
# Markers
|
79 |
+
dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
|
80 |
+
|
81 |
+
|
82 |
+
def parsed_args():
|
83 |
+
"""
|
84 |
+
Parse and returns command-line args
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
argparse.Namespace: the parsed arguments
|
88 |
+
"""
|
89 |
+
parser = ArgumentParser()
|
90 |
+
parser.add_argument(
|
91 |
+
"--input_csv",
|
92 |
+
default="ablations_metrics_20210311.csv",
|
93 |
+
type=str,
|
94 |
+
help="CSV containing the results of the ablation study",
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--output_dir",
|
98 |
+
default=None,
|
99 |
+
type=str,
|
100 |
+
help="Output directory",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--dpi",
|
104 |
+
default=200,
|
105 |
+
type=int,
|
106 |
+
help="DPI for the output images",
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"--n_bs",
|
110 |
+
default=1e6,
|
111 |
+
type=int,
|
112 |
+
help="Number of bootrstrap samples",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--alpha",
|
116 |
+
default=0.99,
|
117 |
+
type=float,
|
118 |
+
help="Confidence level",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--bs_seed",
|
122 |
+
default=17,
|
123 |
+
type=int,
|
124 |
+
help="Bootstrap random seed, for reproducibility",
|
125 |
+
)
|
126 |
+
|
127 |
+
return parser.parse_args()
|
128 |
+
|
129 |
+
|
130 |
+
def trim_mean_wrapper(a):
|
131 |
+
return trim_mean(a, proportiontocut=0.2)
|
132 |
+
|
133 |
+
|
134 |
+
def find_model_pairs(technique, model_feats):
|
135 |
+
model_pairs = []
|
136 |
+
for mi in df.loc[df[technique]].model_feats.unique():
|
137 |
+
for mj in df.model_feats.unique():
|
138 |
+
if mj == mi:
|
139 |
+
continue
|
140 |
+
|
141 |
+
if df.loc[df.model_feats == mj, technique].unique()[0]:
|
142 |
+
continue
|
143 |
+
|
144 |
+
is_pair = True
|
145 |
+
for f in model_feats:
|
146 |
+
if f == technique:
|
147 |
+
continue
|
148 |
+
elif (
|
149 |
+
df.loc[df.model_feats == mj, f].unique()[0]
|
150 |
+
!= df.loc[df.model_feats == mi, f].unique()[0]
|
151 |
+
):
|
152 |
+
is_pair = False
|
153 |
+
break
|
154 |
+
else:
|
155 |
+
pass
|
156 |
+
if is_pair:
|
157 |
+
model_pairs.append((mi, mj))
|
158 |
+
break
|
159 |
+
return model_pairs
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
# -----------------------------
|
164 |
+
# ----- Parse arguments -----
|
165 |
+
# -----------------------------
|
166 |
+
args = parsed_args()
|
167 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
168 |
+
|
169 |
+
# Determine output dir
|
170 |
+
if args.output_dir is None:
|
171 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
172 |
+
else:
|
173 |
+
output_dir = Path(args.output_dir)
|
174 |
+
if not output_dir.exists():
|
175 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
176 |
+
|
177 |
+
# Store args
|
178 |
+
output_yml = output_dir / "bootstrap_summary.yml"
|
179 |
+
with open(output_yml, "w") as f:
|
180 |
+
yaml.dump(vars(args), f)
|
181 |
+
|
182 |
+
# Read CSV
|
183 |
+
df = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
184 |
+
|
185 |
+
# Build data set
|
186 |
+
dfbs = pd.DataFrame(columns=["diff", "technique", "metric"])
|
187 |
+
for technique in model_feats:
|
188 |
+
|
189 |
+
# Get pairs
|
190 |
+
model_pairs = find_model_pairs(technique, model_feats)
|
191 |
+
|
192 |
+
# Compute differences
|
193 |
+
for m_with, m_without in model_pairs:
|
194 |
+
df_with = df.loc[df.model_feats == m_with]
|
195 |
+
df_without = df.loc[df.model_feats == m_without]
|
196 |
+
for metric in dict_metrics["key_metrics"]:
|
197 |
+
diff = (
|
198 |
+
df_with.sort_values(by="img_idx")[metric].values
|
199 |
+
- df_without.sort_values(by="img_idx")[metric].values
|
200 |
+
)
|
201 |
+
dfm = pd.DataFrame.from_dict(
|
202 |
+
{"metric": metric, "technique": technique, "diff": diff}
|
203 |
+
)
|
204 |
+
dfbs = dfbs.append(dfm, ignore_index=True)
|
205 |
+
|
206 |
+
### Plot
|
207 |
+
|
208 |
+
# Set up plot
|
209 |
+
sns.reset_orig()
|
210 |
+
sns.set(style="whitegrid")
|
211 |
+
plt.rcParams.update({"font.family": "serif"})
|
212 |
+
plt.rcParams.update(
|
213 |
+
{
|
214 |
+
"font.serif": [
|
215 |
+
"Computer Modern Roman",
|
216 |
+
"Times New Roman",
|
217 |
+
"Utopia",
|
218 |
+
"New Century Schoolbook",
|
219 |
+
"Century Schoolbook L",
|
220 |
+
"ITC Bookman",
|
221 |
+
"Bookman",
|
222 |
+
"Times",
|
223 |
+
"Palatino",
|
224 |
+
"Charter",
|
225 |
+
"serif" "Bitstream Vera Serif",
|
226 |
+
"DejaVu Serif",
|
227 |
+
]
|
228 |
+
}
|
229 |
+
)
|
230 |
+
|
231 |
+
fig, axes = plt.subplots(
|
232 |
+
nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3)
|
233 |
+
)
|
234 |
+
|
235 |
+
metrics = ["error", "f05", "edge_coherence"]
|
236 |
+
dict_ci = {m: {} for m in metrics}
|
237 |
+
|
238 |
+
for idx, metric in enumerate(dict_metrics["key_metrics"]):
|
239 |
+
|
240 |
+
ax = sns.pointplot(
|
241 |
+
ax=axes[idx],
|
242 |
+
data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])],
|
243 |
+
order=dict_techniques.keys(),
|
244 |
+
x="diff",
|
245 |
+
y="technique",
|
246 |
+
hue="metric",
|
247 |
+
hue_order=[metric],
|
248 |
+
markers=dict_markers[metric],
|
249 |
+
palette=[palette_metrics[idx]],
|
250 |
+
errwidth=1.5,
|
251 |
+
scale=0.6,
|
252 |
+
join=False,
|
253 |
+
estimator=trim_mean_wrapper,
|
254 |
+
ci=int(args.alpha * 100),
|
255 |
+
n_boot=args.n_bs,
|
256 |
+
seed=args.bs_seed,
|
257 |
+
)
|
258 |
+
|
259 |
+
# Retrieve confidence intervals and update results dictionary
|
260 |
+
for line, technique in zip(ax.lines, dict_techniques.keys()):
|
261 |
+
dict_ci[metric].update(
|
262 |
+
{
|
263 |
+
technique: {
|
264 |
+
"20_trimmed_mean": float(
|
265 |
+
trim_mean_wrapper(
|
266 |
+
dfbs.loc[
|
267 |
+
(dfbs.technique == technique)
|
268 |
+
& (dfbs.metric == metrics[idx]),
|
269 |
+
"diff",
|
270 |
+
].values
|
271 |
+
)
|
272 |
+
),
|
273 |
+
"ci_left": float(line.get_xdata()[0]),
|
274 |
+
"ci_right": float(line.get_xdata()[1]),
|
275 |
+
}
|
276 |
+
}
|
277 |
+
)
|
278 |
+
|
279 |
+
leg_handles, leg_labels = ax.get_legend_handles_labels()
|
280 |
+
|
281 |
+
# Change spines
|
282 |
+
sns.despine(left=True, bottom=True)
|
283 |
+
|
284 |
+
# Set Y-label
|
285 |
+
ax.set_ylabel(None)
|
286 |
+
|
287 |
+
# Y-tick labels
|
288 |
+
ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium")
|
289 |
+
|
290 |
+
# Set X-label
|
291 |
+
ax.set_xlabel(None)
|
292 |
+
|
293 |
+
# X-ticks
|
294 |
+
xticks = ax.get_xticks()
|
295 |
+
xticklabels = xticks
|
296 |
+
ax.set_xticks(xticks)
|
297 |
+
ax.set_xticklabels(xticklabels, fontsize="small")
|
298 |
+
|
299 |
+
# Y-lim
|
300 |
+
display2data = ax.transData.inverted()
|
301 |
+
ax2display = ax.transAxes
|
302 |
+
_, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02)))
|
303 |
+
_, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98)))
|
304 |
+
ax.set_ylim(bottom=y_bottom, top=y_top)
|
305 |
+
|
306 |
+
# Draw line at H0
|
307 |
+
y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1)
|
308 |
+
x = 0.0 * np.ones(y.shape[0])
|
309 |
+
ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
|
310 |
+
|
311 |
+
# Draw gray area
|
312 |
+
xlim = ax.get_xlim()
|
313 |
+
ylim = ax.get_ylim()
|
314 |
+
if metric == "error":
|
315 |
+
x0 = xlim[0]
|
316 |
+
width = np.abs(x0)
|
317 |
+
else:
|
318 |
+
x0 = 0.0
|
319 |
+
width = np.abs(xlim[1])
|
320 |
+
trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
|
321 |
+
rect = mpatches.Rectangle(
|
322 |
+
xy=(x0, 0.0),
|
323 |
+
width=width,
|
324 |
+
height=1,
|
325 |
+
transform=trans,
|
326 |
+
linewidth=0.0,
|
327 |
+
edgecolor="none",
|
328 |
+
facecolor="gray",
|
329 |
+
alpha=0.05,
|
330 |
+
)
|
331 |
+
ax.add_patch(rect)
|
332 |
+
|
333 |
+
# Legend
|
334 |
+
leg_handles, leg_labels = ax.get_legend_handles_labels()
|
335 |
+
leg_labels = [dict_metrics["names"][metric] for metric in leg_labels]
|
336 |
+
leg = ax.legend(
|
337 |
+
handles=leg_handles,
|
338 |
+
labels=leg_labels,
|
339 |
+
loc="center",
|
340 |
+
title="",
|
341 |
+
bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0),
|
342 |
+
framealpha=1.0,
|
343 |
+
frameon=False,
|
344 |
+
handletextpad=-0.2,
|
345 |
+
)
|
346 |
+
|
347 |
+
# Set X-label (title) │
|
348 |
+
fig.suptitle(
|
349 |
+
"20 % trimmed mean difference and bootstrapped confidence intervals",
|
350 |
+
y=0.0,
|
351 |
+
fontsize="medium",
|
352 |
+
)
|
353 |
+
|
354 |
+
# Save figure
|
355 |
+
output_fig = output_dir / "bootstrap_summary.png"
|
356 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
357 |
+
|
358 |
+
# Store results
|
359 |
+
output_results = output_dir / "bootstrap_summary_results.yml"
|
360 |
+
with open(output_results, "w") as f:
|
361 |
+
yaml.dump(dict_ci, f)
|
figures/human_evaluation.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This script plots the result of the human evaluation on Amazon Mechanical Turk, where
|
3 |
+
human participants chose between an image from ClimateGAN or from a different method.
|
4 |
+
"""
|
5 |
+
print("Imports...", end="")
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
import os
|
8 |
+
import yaml
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import seaborn as sns
|
12 |
+
from pathlib import Path
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
|
16 |
+
# -----------------------
|
17 |
+
# ----- Constants -----
|
18 |
+
# -----------------------
|
19 |
+
|
20 |
+
comparables_dict = {
|
21 |
+
"munit_flooded": "MUNIT",
|
22 |
+
"cyclegan": "CycleGAN",
|
23 |
+
"instagan": "InstaGAN",
|
24 |
+
"instagan_copypaste": "Mask-InstaGAN",
|
25 |
+
"painted_ground": "Painted ground",
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
# Colors
|
30 |
+
palette_colorblind = sns.color_palette("colorblind")
|
31 |
+
color_climategan = palette_colorblind[9]
|
32 |
+
|
33 |
+
palette_colorblind = sns.color_palette("colorblind")
|
34 |
+
color_munit = palette_colorblind[1]
|
35 |
+
color_cyclegan = palette_colorblind[2]
|
36 |
+
color_instagan = palette_colorblind[3]
|
37 |
+
color_maskinstagan = palette_colorblind[6]
|
38 |
+
color_paintedground = palette_colorblind[8]
|
39 |
+
palette_comparables = [
|
40 |
+
color_munit,
|
41 |
+
color_cyclegan,
|
42 |
+
color_instagan,
|
43 |
+
color_maskinstagan,
|
44 |
+
color_paintedground,
|
45 |
+
]
|
46 |
+
palette_comparables_light = [
|
47 |
+
sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
|
48 |
+
]
|
49 |
+
|
50 |
+
|
51 |
+
def parsed_args():
|
52 |
+
"""
|
53 |
+
Parse and returns command-line args
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
argparse.Namespace: the parsed arguments
|
57 |
+
"""
|
58 |
+
parser = ArgumentParser()
|
59 |
+
parser.add_argument(
|
60 |
+
"--input_csv",
|
61 |
+
default="amt_omni-vs-other.csv",
|
62 |
+
type=str,
|
63 |
+
help="CSV containing the results of the human evaluation, pre-processed",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--output_dir",
|
67 |
+
default=None,
|
68 |
+
type=str,
|
69 |
+
help="Output directory",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--dpi",
|
73 |
+
default=200,
|
74 |
+
type=int,
|
75 |
+
help="DPI for the output images",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--n_bs",
|
79 |
+
default=1e6,
|
80 |
+
type=int,
|
81 |
+
help="Number of bootrstrap samples",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--bs_seed",
|
85 |
+
default=17,
|
86 |
+
type=int,
|
87 |
+
help="Bootstrap random seed, for reproducibility",
|
88 |
+
)
|
89 |
+
|
90 |
+
return parser.parse_args()
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
# -----------------------------
|
95 |
+
# ----- Parse arguments -----
|
96 |
+
# -----------------------------
|
97 |
+
args = parsed_args()
|
98 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
99 |
+
|
100 |
+
# Determine output dir
|
101 |
+
if args.output_dir is None:
|
102 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
103 |
+
else:
|
104 |
+
output_dir = Path(args.output_dir)
|
105 |
+
if not output_dir.exists():
|
106 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
107 |
+
|
108 |
+
# Store args
|
109 |
+
output_yml = output_dir / "args_human_evaluation.yml"
|
110 |
+
with open(output_yml, "w") as f:
|
111 |
+
yaml.dump(vars(args), f)
|
112 |
+
|
113 |
+
# Read CSV
|
114 |
+
df = pd.read_csv(args.input_csv)
|
115 |
+
|
116 |
+
# Sort Y labels
|
117 |
+
comparables = df.comparable.unique()
|
118 |
+
is_climategan_sum = [
|
119 |
+
df.loc[df.comparable == c, "climategan"].sum() for c in comparables
|
120 |
+
]
|
121 |
+
comparables = comparables[np.argsort(is_climategan_sum)[::-1]]
|
122 |
+
|
123 |
+
# Plot setup
|
124 |
+
sns.set(style="whitegrid")
|
125 |
+
plt.rcParams.update({"font.family": "serif"})
|
126 |
+
plt.rcParams.update(
|
127 |
+
{
|
128 |
+
"font.serif": [
|
129 |
+
"Computer Modern Roman",
|
130 |
+
"Times New Roman",
|
131 |
+
"Utopia",
|
132 |
+
"New Century Schoolbook",
|
133 |
+
"Century Schoolbook L",
|
134 |
+
"ITC Bookman",
|
135 |
+
"Bookman",
|
136 |
+
"Times",
|
137 |
+
"Palatino",
|
138 |
+
"Charter",
|
139 |
+
"serif" "Bitstream Vera Serif",
|
140 |
+
"DejaVu Serif",
|
141 |
+
]
|
142 |
+
}
|
143 |
+
)
|
144 |
+
fontsize = "medium"
|
145 |
+
|
146 |
+
# Initialize the matplotlib figure
|
147 |
+
fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)
|
148 |
+
|
149 |
+
# Plot the total (right)
|
150 |
+
sns.barplot(
|
151 |
+
data=df.loc[df.is_valid],
|
152 |
+
x="is_valid",
|
153 |
+
y="comparable",
|
154 |
+
order=comparables,
|
155 |
+
orient="h",
|
156 |
+
label="comparable",
|
157 |
+
palette=palette_comparables_light,
|
158 |
+
ci=None,
|
159 |
+
)
|
160 |
+
|
161 |
+
# Plot the left
|
162 |
+
sns.barplot(
|
163 |
+
data=df.loc[df.is_valid],
|
164 |
+
x="climategan",
|
165 |
+
y="comparable",
|
166 |
+
order=comparables,
|
167 |
+
orient="h",
|
168 |
+
label="climategan",
|
169 |
+
color=color_climategan,
|
170 |
+
ci=99,
|
171 |
+
n_boot=args.n_bs,
|
172 |
+
seed=args.bs_seed,
|
173 |
+
errcolor="black",
|
174 |
+
errwidth=1.5,
|
175 |
+
capsize=0.1,
|
176 |
+
)
|
177 |
+
|
178 |
+
# Draw line at 0.5
|
179 |
+
y = np.arange(ax.get_ylim()[1] + 0.1, ax.get_ylim()[0], 0.1)
|
180 |
+
x = 0.5 * np.ones(y.shape[0])
|
181 |
+
ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
|
182 |
+
|
183 |
+
# Change Y-Tick labels
|
184 |
+
yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
|
185 |
+
yticklabels_text = ax.set_yticklabels(
|
186 |
+
yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
|
187 |
+
)
|
188 |
+
for ytl in yticklabels_text:
|
189 |
+
ax.add_artist(ytl)
|
190 |
+
|
191 |
+
# Remove Y-label
|
192 |
+
ax.set_ylabel(ylabel="")
|
193 |
+
|
194 |
+
# Change X-Tick labels
|
195 |
+
xlim = [0.0, 1.1]
|
196 |
+
xticks = np.arange(xlim[0], xlim[1], 0.1)
|
197 |
+
ax.set(xticks=xticks)
|
198 |
+
plt.setp(ax.get_xticklabels(), fontsize=fontsize)
|
199 |
+
|
200 |
+
# Set X-label
|
201 |
+
ax.set_xlabel(None)
|
202 |
+
|
203 |
+
# Change spines
|
204 |
+
sns.despine(left=True, bottom=True)
|
205 |
+
|
206 |
+
# Save figure
|
207 |
+
output_fig = output_dir / "human_evaluation_rate_climategan.png"
|
208 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
figures/labels.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This scripts plots images from the Masker test set overlaid with their labels.
|
3 |
+
"""
|
4 |
+
print("Imports...", end="")
|
5 |
+
from argparse import ArgumentParser
|
6 |
+
import os
|
7 |
+
import yaml
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import seaborn as sns
|
11 |
+
from pathlib import Path
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import matplotlib.patches as mpatches
|
14 |
+
|
15 |
+
import sys
|
16 |
+
|
17 |
+
sys.path.append("../")
|
18 |
+
|
19 |
+
from eval_masker import crop_and_resize
|
20 |
+
|
21 |
+
|
22 |
+
# -----------------------
|
23 |
+
# ----- Constants -----
|
24 |
+
# -----------------------
|
25 |
+
|
26 |
+
# Colors
|
27 |
+
colorblind_palette = sns.color_palette("colorblind")
|
28 |
+
color_cannot = colorblind_palette[1]
|
29 |
+
color_must = colorblind_palette[2]
|
30 |
+
color_may = colorblind_palette[7]
|
31 |
+
color_pred = colorblind_palette[4]
|
32 |
+
|
33 |
+
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
|
34 |
+
color_tp = icefire[0]
|
35 |
+
color_tn = icefire[1]
|
36 |
+
color_fp = icefire[4]
|
37 |
+
color_fn = icefire[3]
|
38 |
+
|
39 |
+
|
40 |
+
def parsed_args():
|
41 |
+
"""
|
42 |
+
Parse and returns command-line args
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
argparse.Namespace: the parsed arguments
|
46 |
+
"""
|
47 |
+
parser = ArgumentParser()
|
48 |
+
parser.add_argument(
|
49 |
+
"--input_csv",
|
50 |
+
default="ablations_metrics_20210311.csv",
|
51 |
+
type=str,
|
52 |
+
help="CSV containing the results of the ablation study",
|
53 |
+
)
|
54 |
+
parser.add_argument(
|
55 |
+
"--output_dir",
|
56 |
+
default=None,
|
57 |
+
type=str,
|
58 |
+
help="Output directory",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--masker_test_set_dir",
|
62 |
+
default=None,
|
63 |
+
type=str,
|
64 |
+
help="Directory containing the test images",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--images",
|
68 |
+
nargs="+",
|
69 |
+
help="List of image file names to plot",
|
70 |
+
default=[],
|
71 |
+
type=str,
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--dpi",
|
75 |
+
default=200,
|
76 |
+
type=int,
|
77 |
+
help="DPI for the output images",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--alpha",
|
81 |
+
default=0.5,
|
82 |
+
type=float,
|
83 |
+
help="Transparency of labels shade",
|
84 |
+
)
|
85 |
+
|
86 |
+
return parser.parse_args()
|
87 |
+
|
88 |
+
|
89 |
+
def map_color(arr, input_color, output_color, rtol=1e-09):
|
90 |
+
"""
|
91 |
+
Maps one color to another
|
92 |
+
"""
|
93 |
+
input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
|
94 |
+
output = arr.copy()
|
95 |
+
output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
|
96 |
+
return output
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
# -----------------------------
|
101 |
+
# ----- Parse arguments -----
|
102 |
+
# -----------------------------
|
103 |
+
args = parsed_args()
|
104 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
105 |
+
|
106 |
+
# Determine output dir
|
107 |
+
if args.output_dir is None:
|
108 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
109 |
+
else:
|
110 |
+
output_dir = Path(args.output_dir)
|
111 |
+
if not output_dir.exists():
|
112 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
113 |
+
|
114 |
+
# Store args
|
115 |
+
output_yml = output_dir / "labels.yml"
|
116 |
+
with open(output_yml, "w") as f:
|
117 |
+
yaml.dump(vars(args), f)
|
118 |
+
|
119 |
+
# Data dirs
|
120 |
+
imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
|
121 |
+
labels_path = Path(args.masker_test_set_dir) / "labels"
|
122 |
+
|
123 |
+
# Read CSV
|
124 |
+
df = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
125 |
+
|
126 |
+
# Set up plot
|
127 |
+
sns.reset_orig()
|
128 |
+
sns.set(style="whitegrid")
|
129 |
+
plt.rcParams.update({"font.family": "serif"})
|
130 |
+
plt.rcParams.update(
|
131 |
+
{
|
132 |
+
"font.serif": [
|
133 |
+
"Computer Modern Roman",
|
134 |
+
"Times New Roman",
|
135 |
+
"Utopia",
|
136 |
+
"New Century Schoolbook",
|
137 |
+
"Century Schoolbook L",
|
138 |
+
"ITC Bookman",
|
139 |
+
"Bookman",
|
140 |
+
"Times",
|
141 |
+
"Palatino",
|
142 |
+
"Charter",
|
143 |
+
"serif" "Bitstream Vera Serif",
|
144 |
+
"DejaVu Serif",
|
145 |
+
]
|
146 |
+
}
|
147 |
+
)
|
148 |
+
|
149 |
+
fig, axes = plt.subplots(
|
150 |
+
nrows=1, ncols=len(args.images), dpi=args.dpi, figsize=(len(args.images) * 5, 5)
|
151 |
+
)
|
152 |
+
|
153 |
+
for idx, img_filename in enumerate(args.images):
|
154 |
+
|
155 |
+
# Read images
|
156 |
+
img_path = imgs_orig_path / img_filename
|
157 |
+
label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
|
158 |
+
img, label = crop_and_resize(img_path, label_path)
|
159 |
+
|
160 |
+
# Map label colors
|
161 |
+
label_colmap = label.astype(float)
|
162 |
+
label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
|
163 |
+
label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
|
164 |
+
label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
|
165 |
+
|
166 |
+
ax = axes[idx]
|
167 |
+
ax.imshow(img)
|
168 |
+
ax.imshow(label_colmap, alpha=args.alpha)
|
169 |
+
ax.axis("off")
|
170 |
+
|
171 |
+
# Legend
|
172 |
+
handles = []
|
173 |
+
lw = 1.0
|
174 |
+
handles.append(
|
175 |
+
mpatches.Patch(
|
176 |
+
facecolor=color_must, label="must", linewidth=lw, alpha=args.alpha
|
177 |
+
)
|
178 |
+
)
|
179 |
+
handles.append(
|
180 |
+
mpatches.Patch(facecolor=color_may, label="may", linewidth=lw, alpha=args.alpha)
|
181 |
+
)
|
182 |
+
handles.append(
|
183 |
+
mpatches.Patch(
|
184 |
+
facecolor=color_cannot, label="cannot", linewidth=lw, alpha=args.alpha
|
185 |
+
)
|
186 |
+
)
|
187 |
+
labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
|
188 |
+
fig.legend(
|
189 |
+
handles=handles,
|
190 |
+
labels=labels,
|
191 |
+
loc="upper center",
|
192 |
+
bbox_to_anchor=(0.0, 0.85, 1.0, 0.075),
|
193 |
+
ncol=len(args.images),
|
194 |
+
fontsize="medium",
|
195 |
+
frameon=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
# Save figure
|
199 |
+
output_fig = output_dir / "labels.png"
|
200 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
figures/metrics.py
ADDED
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This scripts plots examples of the images that get best and worse metrics
|
3 |
+
"""
|
4 |
+
print("Imports...", end="")
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import matplotlib.patches as mpatches
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
import yaml
|
16 |
+
from imageio import imread
|
17 |
+
from skimage.color import rgba2rgb
|
18 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
19 |
+
|
20 |
+
sys.path.append("../")
|
21 |
+
|
22 |
+
from climategan.data import encode_mask_label
|
23 |
+
from climategan.eval_metrics import edges_coherence_std_min
|
24 |
+
from eval_masker import crop_and_resize
|
25 |
+
|
26 |
+
# -----------------------
|
27 |
+
# ----- Constants -----
|
28 |
+
# -----------------------
|
29 |
+
|
30 |
+
# Metrics
|
31 |
+
metrics = ["error", "f05", "edge_coherence"]
|
32 |
+
|
33 |
+
dict_metrics = {
|
34 |
+
"names": {
|
35 |
+
"tpr": "TPR, Recall, Sensitivity",
|
36 |
+
"tnr": "TNR, Specificity, Selectivity",
|
37 |
+
"fpr": "FPR",
|
38 |
+
"fpt": "False positives relative to image size",
|
39 |
+
"fnr": "FNR, Miss rate",
|
40 |
+
"fnt": "False negatives relative to image size",
|
41 |
+
"mpr": "May positive rate (MPR)",
|
42 |
+
"mnr": "May negative rate (MNR)",
|
43 |
+
"accuracy": "Accuracy (ignoring may)",
|
44 |
+
"error": "Error",
|
45 |
+
"f05": "F05 score",
|
46 |
+
"precision": "Precision",
|
47 |
+
"edge_coherence": "Edge coherence",
|
48 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
49 |
+
},
|
50 |
+
"key_metrics": ["error", "f05", "edge_coherence"],
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
# Colors
|
55 |
+
colorblind_palette = sns.color_palette("colorblind")
|
56 |
+
color_cannot = colorblind_palette[1]
|
57 |
+
color_must = colorblind_palette[2]
|
58 |
+
color_may = colorblind_palette[7]
|
59 |
+
color_pred = colorblind_palette[4]
|
60 |
+
|
61 |
+
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
|
62 |
+
color_tp = icefire[0]
|
63 |
+
color_tn = icefire[1]
|
64 |
+
color_fp = icefire[4]
|
65 |
+
color_fn = icefire[3]
|
66 |
+
|
67 |
+
|
68 |
+
def parsed_args():
|
69 |
+
"""
|
70 |
+
Parse and returns command-line args
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
argparse.Namespace: the parsed arguments
|
74 |
+
"""
|
75 |
+
parser = ArgumentParser()
|
76 |
+
parser.add_argument(
|
77 |
+
"--input_csv",
|
78 |
+
default="ablations_metrics_20210311.csv",
|
79 |
+
type=str,
|
80 |
+
help="CSV containing the results of the ablation study",
|
81 |
+
)
|
82 |
+
parser.add_argument(
|
83 |
+
"--output_dir",
|
84 |
+
default=None,
|
85 |
+
type=str,
|
86 |
+
help="Output directory",
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"--models_log_path",
|
90 |
+
default=None,
|
91 |
+
type=str,
|
92 |
+
help="Path containing the log files of the models",
|
93 |
+
)
|
94 |
+
parser.add_argument(
|
95 |
+
"--masker_test_set_dir",
|
96 |
+
default=None,
|
97 |
+
type=str,
|
98 |
+
help="Directory containing the test images",
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--best_model",
|
102 |
+
default="dada, msd_spade, pseudo",
|
103 |
+
type=str,
|
104 |
+
help="The string identifier of the best model",
|
105 |
+
)
|
106 |
+
parser.add_argument(
|
107 |
+
"--dpi",
|
108 |
+
default=200,
|
109 |
+
type=int,
|
110 |
+
help="DPI for the output images",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--alpha",
|
114 |
+
default=0.5,
|
115 |
+
type=float,
|
116 |
+
help="Transparency of labels shade",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--percentile",
|
120 |
+
default=0.05,
|
121 |
+
type=float,
|
122 |
+
help="Transparency of labels shade",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--seed",
|
126 |
+
default=None,
|
127 |
+
type=int,
|
128 |
+
help="Bootstrap random seed, for reproducibility",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--no_images",
|
132 |
+
action="store_true",
|
133 |
+
default=False,
|
134 |
+
help="Do not generate images",
|
135 |
+
)
|
136 |
+
|
137 |
+
return parser.parse_args()
|
138 |
+
|
139 |
+
|
140 |
+
def map_color(arr, input_color, output_color, rtol=1e-09):
|
141 |
+
"""
|
142 |
+
Maps one color to another
|
143 |
+
"""
|
144 |
+
input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
|
145 |
+
output = arr.copy()
|
146 |
+
output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
|
147 |
+
return output
|
148 |
+
|
149 |
+
|
150 |
+
def plot_labels(ax, img, label, img_id, do_legend):
|
151 |
+
label_colmap = label.astype(float)
|
152 |
+
label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
|
153 |
+
label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
|
154 |
+
label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
|
155 |
+
|
156 |
+
ax.imshow(img)
|
157 |
+
ax.imshow(label_colmap, alpha=0.5)
|
158 |
+
ax.axis("off")
|
159 |
+
|
160 |
+
# Annotation
|
161 |
+
ax.annotate(
|
162 |
+
xy=(0.05, 0.95),
|
163 |
+
xycoords="axes fraction",
|
164 |
+
xytext=(0.05, 0.95),
|
165 |
+
textcoords="axes fraction",
|
166 |
+
text=img_id,
|
167 |
+
fontsize="x-large",
|
168 |
+
verticalalignment="top",
|
169 |
+
color="white",
|
170 |
+
)
|
171 |
+
|
172 |
+
# Legend
|
173 |
+
if do_legend:
|
174 |
+
handles = []
|
175 |
+
lw = 1.0
|
176 |
+
handles.append(
|
177 |
+
mpatches.Patch(facecolor=color_must, label="must", linewidth=lw, alpha=0.66)
|
178 |
+
)
|
179 |
+
handles.append(
|
180 |
+
mpatches.Patch(facecolor=color_may, label="must", linewidth=lw, alpha=0.66)
|
181 |
+
)
|
182 |
+
handles.append(
|
183 |
+
mpatches.Patch(
|
184 |
+
facecolor=color_cannot, label="must", linewidth=lw, alpha=0.66
|
185 |
+
)
|
186 |
+
)
|
187 |
+
labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
|
188 |
+
ax.legend(
|
189 |
+
handles=handles,
|
190 |
+
labels=labels,
|
191 |
+
bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
|
192 |
+
ncol=3,
|
193 |
+
mode="expand",
|
194 |
+
fontsize="xx-small",
|
195 |
+
frameon=False,
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def plot_pred(ax, img, pred, img_id, do_legend):
|
200 |
+
pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
|
201 |
+
|
202 |
+
pred_colmap = pred.astype(float)
|
203 |
+
pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
|
204 |
+
pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
|
205 |
+
pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
|
206 |
+
|
207 |
+
ax.imshow(img)
|
208 |
+
ax.imshow(pred_colmap_ma, alpha=0.5)
|
209 |
+
ax.axis("off")
|
210 |
+
|
211 |
+
# Annotation
|
212 |
+
ax.annotate(
|
213 |
+
xy=(0.05, 0.95),
|
214 |
+
xycoords="axes fraction",
|
215 |
+
xytext=(0.05, 0.95),
|
216 |
+
textcoords="axes fraction",
|
217 |
+
text=img_id,
|
218 |
+
fontsize="x-large",
|
219 |
+
verticalalignment="top",
|
220 |
+
color="white",
|
221 |
+
)
|
222 |
+
|
223 |
+
# Legend
|
224 |
+
if do_legend:
|
225 |
+
handles = []
|
226 |
+
lw = 1.0
|
227 |
+
handles.append(
|
228 |
+
mpatches.Patch(facecolor=color_pred, label="must", linewidth=lw, alpha=0.66)
|
229 |
+
)
|
230 |
+
labels = ["Prediction"]
|
231 |
+
ax.legend(
|
232 |
+
handles=handles,
|
233 |
+
labels=labels,
|
234 |
+
bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
|
235 |
+
ncol=3,
|
236 |
+
mode="expand",
|
237 |
+
fontsize="xx-small",
|
238 |
+
frameon=False,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend):
|
243 |
+
# FP
|
244 |
+
fp_map = imread(
|
245 |
+
model_path / "eval-metrics/fp" / "{}_fp.png".format(Path(img_filename).stem)
|
246 |
+
)
|
247 |
+
fp_map = np.tile(np.expand_dims(fp_map, axis=2), reps=(1, 1, 3))
|
248 |
+
|
249 |
+
fp_map_colmap = fp_map.astype(float)
|
250 |
+
fp_map_colmap = map_color(fp_map_colmap, (1, 1, 1), color_fp)
|
251 |
+
|
252 |
+
# FN
|
253 |
+
fn_map = imread(
|
254 |
+
model_path / "eval-metrics/fn" / "{}_fn.png".format(Path(img_filename).stem)
|
255 |
+
)
|
256 |
+
fn_map = np.tile(np.expand_dims(fn_map, axis=2), reps=(1, 1, 3))
|
257 |
+
|
258 |
+
fn_map_colmap = fn_map.astype(float)
|
259 |
+
fn_map_colmap = map_color(fn_map_colmap, (1, 1, 1), color_fn)
|
260 |
+
|
261 |
+
# TP
|
262 |
+
tp_map = imread(
|
263 |
+
model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
|
264 |
+
)
|
265 |
+
tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
|
266 |
+
|
267 |
+
tp_map_colmap = tp_map.astype(float)
|
268 |
+
tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
|
269 |
+
|
270 |
+
# TN
|
271 |
+
tn_map = imread(
|
272 |
+
model_path / "eval-metrics/tn" / "{}_tn.png".format(Path(img_filename).stem)
|
273 |
+
)
|
274 |
+
tn_map = np.tile(np.expand_dims(tn_map, axis=2), reps=(1, 1, 3))
|
275 |
+
|
276 |
+
tn_map_colmap = tn_map.astype(float)
|
277 |
+
tn_map_colmap = map_color(tn_map_colmap, (1, 1, 1), color_tn)
|
278 |
+
|
279 |
+
label_colmap = label.astype(float)
|
280 |
+
label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
|
281 |
+
label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
|
282 |
+
label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma
|
283 |
+
|
284 |
+
# Combine masks
|
285 |
+
maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
|
286 |
+
maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
|
287 |
+
maps_ma = maps_ma.mask * img + maps_ma
|
288 |
+
|
289 |
+
ax.imshow(img)
|
290 |
+
ax.imshow(label_colmap_ma, alpha=0.5)
|
291 |
+
ax.imshow(maps_ma, alpha=0.5)
|
292 |
+
ax.axis("off")
|
293 |
+
|
294 |
+
# Annotation
|
295 |
+
ax.annotate(
|
296 |
+
xy=(0.05, 0.95),
|
297 |
+
xycoords="axes fraction",
|
298 |
+
xytext=(0.05, 0.95),
|
299 |
+
textcoords="axes fraction",
|
300 |
+
text=img_id,
|
301 |
+
fontsize="x-large",
|
302 |
+
verticalalignment="top",
|
303 |
+
color="white",
|
304 |
+
)
|
305 |
+
|
306 |
+
# Legend
|
307 |
+
if do_legend:
|
308 |
+
handles = []
|
309 |
+
lw = 1.0
|
310 |
+
handles.append(
|
311 |
+
mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
|
312 |
+
)
|
313 |
+
handles.append(
|
314 |
+
mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66)
|
315 |
+
)
|
316 |
+
handles.append(
|
317 |
+
mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66)
|
318 |
+
)
|
319 |
+
handles.append(
|
320 |
+
mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66)
|
321 |
+
)
|
322 |
+
handles.append(
|
323 |
+
mpatches.Patch(
|
324 |
+
facecolor=color_may, label="May-be-flooded", linewidth=lw, alpha=0.66
|
325 |
+
)
|
326 |
+
)
|
327 |
+
labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
|
328 |
+
ax.legend(
|
329 |
+
handles=handles,
|
330 |
+
labels=labels,
|
331 |
+
bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
|
332 |
+
ncol=5,
|
333 |
+
mode="expand",
|
334 |
+
fontsize="xx-small",
|
335 |
+
frameon=False,
|
336 |
+
)
|
337 |
+
|
338 |
+
|
339 |
+
def plot_edge_coherence(ax, img, label, pred, img_id, do_legend):
|
340 |
+
pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
|
341 |
+
|
342 |
+
ec, pred_ec, label_ec = edges_coherence_std_min(
|
343 |
+
np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
|
344 |
+
)
|
345 |
+
|
346 |
+
##################
|
347 |
+
# Edge distances #
|
348 |
+
##################
|
349 |
+
|
350 |
+
# Location of edges
|
351 |
+
pred_ec_coord = np.argwhere(pred_ec > 0)
|
352 |
+
label_ec_coord = np.argwhere(label_ec > 0)
|
353 |
+
|
354 |
+
# Normalized pairwise distances between pred and label
|
355 |
+
dist_mat = np.divide(
|
356 |
+
euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
|
357 |
+
)
|
358 |
+
|
359 |
+
# Standard deviation of the minimum distance from pred to label
|
360 |
+
min_dist = np.min(dist_mat, axis=1) # noqa: F841
|
361 |
+
|
362 |
+
#############
|
363 |
+
# Make plot #
|
364 |
+
#############
|
365 |
+
|
366 |
+
pred_ec = np.tile(
|
367 |
+
np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
|
368 |
+
)
|
369 |
+
pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
|
370 |
+
pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred) # noqa: F841
|
371 |
+
|
372 |
+
label_ec = np.tile(
|
373 |
+
np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
|
374 |
+
)
|
375 |
+
label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
|
376 |
+
label_ec_colmap_ma = np.ma.masked_not_equal( # noqa: F841
|
377 |
+
label_ec_colmap, color_must
|
378 |
+
)
|
379 |
+
|
380 |
+
# Combined pred and label edges
|
381 |
+
combined_ec = pred_ec_colmap + label_ec_colmap
|
382 |
+
combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
|
383 |
+
combined_ec_img = combined_ec_ma.mask * img + combined_ec
|
384 |
+
|
385 |
+
# Pred
|
386 |
+
pred_colmap = pred.astype(float)
|
387 |
+
pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
|
388 |
+
pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
|
389 |
+
|
390 |
+
# Must
|
391 |
+
label_colmap = label.astype(float)
|
392 |
+
label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
|
393 |
+
label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)
|
394 |
+
|
395 |
+
# TP
|
396 |
+
tp_map = imread(
|
397 |
+
model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
|
398 |
+
)
|
399 |
+
tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
|
400 |
+
tp_map_colmap = tp_map.astype(float)
|
401 |
+
tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
|
402 |
+
tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)
|
403 |
+
|
404 |
+
# Combination
|
405 |
+
comb_pred = (
|
406 |
+
(pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
|
407 |
+
& tp_map_colmap_ma.mask
|
408 |
+
& combined_ec_ma.mask
|
409 |
+
) * pred_colmap
|
410 |
+
comb_label = (
|
411 |
+
(label_colmap_ma.mask ^ pred_colmap_ma.mask)
|
412 |
+
& pred_colmap_ma.mask
|
413 |
+
& combined_ec_ma.mask
|
414 |
+
) * label_colmap
|
415 |
+
comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
|
416 |
+
combined = comb_tp + comb_label + comb_pred
|
417 |
+
combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
|
418 |
+
combined_ma = combined_ma.mask * combined_ec_img + combined_ma
|
419 |
+
|
420 |
+
ax.imshow(combined_ec_img, alpha=1)
|
421 |
+
ax.imshow(combined_ma, alpha=0.5)
|
422 |
+
ax.axis("off")
|
423 |
+
|
424 |
+
# Plot lines
|
425 |
+
idx_sort_x = np.argsort(pred_ec_coord[:, 1])
|
426 |
+
offset = 100
|
427 |
+
for idx in range(offset, pred_ec_coord.shape[0], offset):
|
428 |
+
y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
|
429 |
+
argmin = np.argmin(dist_mat[idx_sort_x[idx]])
|
430 |
+
y1, x1 = label_ec_coord[argmin, :]
|
431 |
+
ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)
|
432 |
+
|
433 |
+
# Annotation
|
434 |
+
ax.annotate(
|
435 |
+
xy=(0.05, 0.95),
|
436 |
+
xycoords="axes fraction",
|
437 |
+
xytext=(0.05, 0.95),
|
438 |
+
textcoords="axes fraction",
|
439 |
+
text=img_id,
|
440 |
+
fontsize="x-large",
|
441 |
+
verticalalignment="top",
|
442 |
+
color="white",
|
443 |
+
)
|
444 |
+
# Legend
|
445 |
+
if do_legend:
|
446 |
+
handles = []
|
447 |
+
lw = 1.0
|
448 |
+
handles.append(
|
449 |
+
mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
|
450 |
+
)
|
451 |
+
handles.append(
|
452 |
+
mpatches.Patch(facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66)
|
453 |
+
)
|
454 |
+
handles.append(
|
455 |
+
mpatches.Patch(
|
456 |
+
facecolor=color_must, label="Must-be-flooded", linewidth=lw, alpha=0.66
|
457 |
+
)
|
458 |
+
)
|
459 |
+
labels = ["TP", "Prediction", "Must-be-flooded"]
|
460 |
+
ax.legend(
|
461 |
+
handles=handles,
|
462 |
+
labels=labels,
|
463 |
+
bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
|
464 |
+
ncol=3,
|
465 |
+
mode="expand",
|
466 |
+
fontsize="xx-small",
|
467 |
+
frameon=False,
|
468 |
+
)
|
469 |
+
|
470 |
+
|
471 |
+
def plot_images_metric(axes, metric, img_filename, img_id, do_legend):
|
472 |
+
|
473 |
+
# Read images
|
474 |
+
img_path = imgs_orig_path / img_filename
|
475 |
+
label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
|
476 |
+
img, label = crop_and_resize(img_path, label_path)
|
477 |
+
img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
|
478 |
+
pred = imread(
|
479 |
+
model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
|
480 |
+
)
|
481 |
+
|
482 |
+
# Label
|
483 |
+
plot_labels(axes[0], img, label, img_id, do_legend)
|
484 |
+
|
485 |
+
# Prediction
|
486 |
+
plot_pred(axes[1], img, pred, img_id, do_legend)
|
487 |
+
|
488 |
+
# Correct / incorrect
|
489 |
+
if metric in ["error", "f05"]:
|
490 |
+
plot_correct_incorrect(axes[2], img_filename, img, label, img_id, do_legend)
|
491 |
+
# Edge coherence
|
492 |
+
elif metric == "edge_coherence":
|
493 |
+
plot_edge_coherence(axes[2], img, label, pred, img_id, do_legend)
|
494 |
+
else:
|
495 |
+
raise ValueError
|
496 |
+
|
497 |
+
|
498 |
+
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
|
499 |
+
|
500 |
+
sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
|
501 |
+
|
502 |
+
# Set X-label
|
503 |
+
ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
|
504 |
+
|
505 |
+
# Set Y-label
|
506 |
+
ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")
|
507 |
+
|
508 |
+
# Change spines
|
509 |
+
sns.despine(ax=ax, left=True, bottom=True)
|
510 |
+
|
511 |
+
annotate_scatterplot(ax, dict_images, x_metric, y_metric)
|
512 |
+
|
513 |
+
|
514 |
+
def scatterplot_metrics(ax, df, dict_images):
|
515 |
+
|
516 |
+
sns.scatterplot(data=df, x="error", y="f05", hue="edge_coherence", ax=ax)
|
517 |
+
|
518 |
+
# Set X-label
|
519 |
+
ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
|
520 |
+
|
521 |
+
# Set Y-label
|
522 |
+
ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")
|
523 |
+
|
524 |
+
annotate_scatterplot(ax, dict_images, "error", "f05")
|
525 |
+
|
526 |
+
# Change spines
|
527 |
+
sns.despine(ax=ax, left=True, bottom=True)
|
528 |
+
|
529 |
+
# Set XY limits
|
530 |
+
xlim = ax.get_xlim()
|
531 |
+
ylim = ax.get_ylim()
|
532 |
+
ax.set_xlim([0.0, xlim[1]])
|
533 |
+
ax.set_ylim([ylim[0], 1.0])
|
534 |
+
|
535 |
+
|
536 |
+
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
|
537 |
+
xlim = ax.get_xlim()
|
538 |
+
ylim = ax.get_ylim()
|
539 |
+
x_len = xlim[1] - xlim[0]
|
540 |
+
y_len = ylim[1] - ylim[0]
|
541 |
+
x_th = xlim[1] - x_len / 2.0
|
542 |
+
y_th = ylim[1] - y_len / 2.0
|
543 |
+
for text, d in dict_images.items():
|
544 |
+
x = d[x_metric]
|
545 |
+
y = d[y_metric]
|
546 |
+
x_text = x + x_len * offset if x < x_th else x - x_len * offset
|
547 |
+
y_text = y + y_len * offset if y < y_th else y - y_len * offset
|
548 |
+
ax.annotate(
|
549 |
+
xy=(x, y),
|
550 |
+
xycoords="data",
|
551 |
+
xytext=(x_text, y_text),
|
552 |
+
textcoords="data",
|
553 |
+
text=text,
|
554 |
+
arrowprops=dict(facecolor="black", shrink=0.05),
|
555 |
+
fontsize="medium",
|
556 |
+
color="black",
|
557 |
+
)
|
558 |
+
|
559 |
+
|
560 |
+
if __name__ == "__main__":
|
561 |
+
# -----------------------------
|
562 |
+
# ----- Parse arguments -----
|
563 |
+
# -----------------------------
|
564 |
+
args = parsed_args()
|
565 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
566 |
+
|
567 |
+
# Determine output dir
|
568 |
+
if args.output_dir is None:
|
569 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
570 |
+
else:
|
571 |
+
output_dir = Path(args.output_dir)
|
572 |
+
if not output_dir.exists():
|
573 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
574 |
+
|
575 |
+
# Store args
|
576 |
+
output_yml = output_dir / "labels.yml"
|
577 |
+
with open(output_yml, "w") as f:
|
578 |
+
yaml.dump(vars(args), f)
|
579 |
+
|
580 |
+
# Data dirs
|
581 |
+
imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
|
582 |
+
labels_path = Path(args.masker_test_set_dir) / "labels"
|
583 |
+
|
584 |
+
# Read CSV
|
585 |
+
df = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
586 |
+
|
587 |
+
# Select best model
|
588 |
+
df = df.loc[df.model_feats == args.best_model]
|
589 |
+
v_key, model_dir = df.model.unique()[0].split("/")
|
590 |
+
model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
|
591 |
+
|
592 |
+
# Set up plot
|
593 |
+
sns.reset_orig()
|
594 |
+
sns.set(style="whitegrid")
|
595 |
+
plt.rcParams.update({"font.family": "serif"})
|
596 |
+
plt.rcParams.update(
|
597 |
+
{
|
598 |
+
"font.serif": [
|
599 |
+
"Computer Modern Roman",
|
600 |
+
"Times New Roman",
|
601 |
+
"Utopia",
|
602 |
+
"New Century Schoolbook",
|
603 |
+
"Century Schoolbook L",
|
604 |
+
"ITC Bookman",
|
605 |
+
"Bookman",
|
606 |
+
"Times",
|
607 |
+
"Palatino",
|
608 |
+
"Charter",
|
609 |
+
"serif" "Bitstream Vera Serif",
|
610 |
+
"DejaVu Serif",
|
611 |
+
]
|
612 |
+
}
|
613 |
+
)
|
614 |
+
|
615 |
+
if args.seed:
|
616 |
+
np.random.seed(args.seed)
|
617 |
+
img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
618 |
+
dict_images = {}
|
619 |
+
idx = 0
|
620 |
+
for metric in metrics:
|
621 |
+
|
622 |
+
fig, axes = plt.subplots(nrows=2, ncols=3, dpi=200, figsize=(18, 12))
|
623 |
+
|
624 |
+
# Select best
|
625 |
+
if metric == "error":
|
626 |
+
ascending = True
|
627 |
+
else:
|
628 |
+
ascending = False
|
629 |
+
idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
|
630 |
+
srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
|
631 |
+
img_id = img_ids[idx]
|
632 |
+
dict_images.update({img_id: srs_sel})
|
633 |
+
|
634 |
+
# Read images
|
635 |
+
img_filename = srs_sel.filename
|
636 |
+
|
637 |
+
if not args.no_images:
|
638 |
+
axes_row = axes[0, :]
|
639 |
+
plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=True)
|
640 |
+
|
641 |
+
idx += 1
|
642 |
+
|
643 |
+
# Select worst
|
644 |
+
if metric == "error":
|
645 |
+
ascending = False
|
646 |
+
else:
|
647 |
+
ascending = True
|
648 |
+
idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
|
649 |
+
srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
|
650 |
+
img_id = img_ids[idx]
|
651 |
+
dict_images.update({img_id: srs_sel})
|
652 |
+
|
653 |
+
# Read images
|
654 |
+
img_filename = srs_sel.filename
|
655 |
+
|
656 |
+
if not args.no_images:
|
657 |
+
axes_row = axes[1, :]
|
658 |
+
plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=False)
|
659 |
+
|
660 |
+
idx += 1
|
661 |
+
|
662 |
+
# Save figure
|
663 |
+
output_fig = output_dir / "{}.png".format(metric)
|
664 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
665 |
+
|
666 |
+
fig = plt.figure(dpi=200)
|
667 |
+
scatterplot_metrics(fig.gca(), df, dict_images)
|
668 |
+
|
669 |
+
# fig, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
|
670 |
+
#
|
671 |
+
# scatterplot_metrics_pair(axes[0], df, 'error', 'f05', dict_images)
|
672 |
+
# scatterplot_metrics_pair(axes[1], df, 'error', 'edge_coherence', dict_images)
|
673 |
+
# scatterplot_metrics_pair(axes[2], df, 'f05', 'edge_coherence', dict_images)
|
674 |
+
#
|
675 |
+
output_fig = output_dir / "scatterplots.png"
|
676 |
+
fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
|
figures/metrics_onefig.py
ADDED
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This scripts plots examples of the images that get best and worse metrics
|
3 |
+
"""
|
4 |
+
print("Imports...", end="")
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import matplotlib.patches as mpatches
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import seaborn as sns
|
15 |
+
import yaml
|
16 |
+
from imageio import imread
|
17 |
+
from matplotlib.gridspec import GridSpec
|
18 |
+
from skimage.color import rgba2rgb
|
19 |
+
from sklearn.metrics.pairwise import euclidean_distances
|
20 |
+
|
21 |
+
sys.path.append("../")
|
22 |
+
|
23 |
+
from climategan.data import encode_mask_label
|
24 |
+
from climategan.eval_metrics import edges_coherence_std_min
|
25 |
+
from eval_masker import crop_and_resize
|
26 |
+
|
27 |
+
# -----------------------
|
28 |
+
# ----- Constants -----
|
29 |
+
# -----------------------
|
30 |
+
|
31 |
+
# Metrics
|
32 |
+
metrics = ["error", "f05", "edge_coherence"]
|
33 |
+
|
34 |
+
dict_metrics = {
|
35 |
+
"names": {
|
36 |
+
"tpr": "TPR, Recall, Sensitivity",
|
37 |
+
"tnr": "TNR, Specificity, Selectivity",
|
38 |
+
"fpr": "FPR",
|
39 |
+
"fpt": "False positives relative to image size",
|
40 |
+
"fnr": "FNR, Miss rate",
|
41 |
+
"fnt": "False negatives relative to image size",
|
42 |
+
"mpr": "May positive rate (MPR)",
|
43 |
+
"mnr": "May negative rate (MNR)",
|
44 |
+
"accuracy": "Accuracy (ignoring may)",
|
45 |
+
"error": "Error",
|
46 |
+
"f05": "F05 score",
|
47 |
+
"precision": "Precision",
|
48 |
+
"edge_coherence": "Edge coherence",
|
49 |
+
"accuracy_must_may": "Accuracy (ignoring cannot)",
|
50 |
+
},
|
51 |
+
"key_metrics": ["error", "f05", "edge_coherence"],
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
# Colors
|
56 |
+
colorblind_palette = sns.color_palette("colorblind")
|
57 |
+
color_cannot = colorblind_palette[1]
|
58 |
+
color_must = colorblind_palette[2]
|
59 |
+
color_may = colorblind_palette[7]
|
60 |
+
color_pred = colorblind_palette[4]
|
61 |
+
|
62 |
+
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
|
63 |
+
color_tp = icefire[0]
|
64 |
+
color_tn = icefire[1]
|
65 |
+
color_fp = icefire[4]
|
66 |
+
color_fn = icefire[3]
|
67 |
+
|
68 |
+
|
69 |
+
def parsed_args():
|
70 |
+
"""
|
71 |
+
Parse and returns command-line args
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
argparse.Namespace: the parsed arguments
|
75 |
+
"""
|
76 |
+
parser = ArgumentParser()
|
77 |
+
parser.add_argument(
|
78 |
+
"--input_csv",
|
79 |
+
default="ablations_metrics_20210311.csv",
|
80 |
+
type=str,
|
81 |
+
help="CSV containing the results of the ablation study",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--output_dir",
|
85 |
+
default=None,
|
86 |
+
type=str,
|
87 |
+
help="Output directory",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--models_log_path",
|
91 |
+
default=None,
|
92 |
+
type=str,
|
93 |
+
help="Path containing the log files of the models",
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--masker_test_set_dir",
|
97 |
+
default=None,
|
98 |
+
type=str,
|
99 |
+
help="Directory containing the test images",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--best_model",
|
103 |
+
default="dada, msd_spade, pseudo",
|
104 |
+
type=str,
|
105 |
+
help="The string identifier of the best model",
|
106 |
+
)
|
107 |
+
parser.add_argument(
|
108 |
+
"--dpi",
|
109 |
+
default=200,
|
110 |
+
type=int,
|
111 |
+
help="DPI for the output images",
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--alpha",
|
115 |
+
default=0.5,
|
116 |
+
type=float,
|
117 |
+
help="Transparency of labels shade",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--percentile",
|
121 |
+
default=0.05,
|
122 |
+
type=float,
|
123 |
+
help="Transparency of labels shade",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--seed",
|
127 |
+
default=None,
|
128 |
+
type=int,
|
129 |
+
help="Bootstrap random seed, for reproducibility",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--no_images",
|
133 |
+
action="store_true",
|
134 |
+
default=False,
|
135 |
+
help="Do not generate images",
|
136 |
+
)
|
137 |
+
|
138 |
+
return parser.parse_args()
|
139 |
+
|
140 |
+
|
141 |
+
def map_color(arr, input_color, output_color, rtol=1e-09):
|
142 |
+
"""
|
143 |
+
Maps one color to another
|
144 |
+
"""
|
145 |
+
input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
|
146 |
+
output = arr.copy()
|
147 |
+
output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
|
148 |
+
return output
|
149 |
+
|
150 |
+
|
151 |
+
def plot_labels(ax, img, label, img_id, n_, add_title, do_legend):
|
152 |
+
label_colmap = label.astype(float)
|
153 |
+
label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
|
154 |
+
label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
|
155 |
+
label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
|
156 |
+
|
157 |
+
ax.imshow(img)
|
158 |
+
ax.imshow(label_colmap, alpha=0.5)
|
159 |
+
ax.axis("off")
|
160 |
+
|
161 |
+
if n_ in [1, 3, 5]:
|
162 |
+
color_ = "green"
|
163 |
+
else:
|
164 |
+
color_ = "red"
|
165 |
+
|
166 |
+
ax.text(
|
167 |
+
-0.15,
|
168 |
+
0.5,
|
169 |
+
img_id,
|
170 |
+
color=color_,
|
171 |
+
fontweight="roman",
|
172 |
+
fontsize="x-large",
|
173 |
+
horizontalalignment="left",
|
174 |
+
verticalalignment="center",
|
175 |
+
transform=ax.transAxes,
|
176 |
+
)
|
177 |
+
|
178 |
+
if add_title:
|
179 |
+
ax.set_title("Labels", rotation=0, fontsize="x-large")
|
180 |
+
|
181 |
+
|
182 |
+
def plot_pred(ax, img, pred, img_id, add_title, do_legend):
|
183 |
+
pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
|
184 |
+
|
185 |
+
pred_colmap = pred.astype(float)
|
186 |
+
pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
|
187 |
+
pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
|
188 |
+
pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
|
189 |
+
|
190 |
+
ax.imshow(img)
|
191 |
+
ax.imshow(pred_colmap_ma, alpha=0.5)
|
192 |
+
ax.axis("off")
|
193 |
+
|
194 |
+
if add_title:
|
195 |
+
ax.set_title("Prediction", rotation=0, fontsize="x-large")
|
196 |
+
|
197 |
+
|
198 |
+
def plot_correct_incorrect(
|
199 |
+
ax, img_filename, img, metric, label, img_id, n_, add_title, do_legend
|
200 |
+
):
|
201 |
+
# FP
|
202 |
+
fp_map = imread(
|
203 |
+
model_path / "eval-metrics/fp" / "{}_fp.png".format(Path(img_filename).stem)
|
204 |
+
)
|
205 |
+
fp_map = np.tile(np.expand_dims(fp_map, axis=2), reps=(1, 1, 3))
|
206 |
+
|
207 |
+
fp_map_colmap = fp_map.astype(float)
|
208 |
+
fp_map_colmap = map_color(fp_map_colmap, (1, 1, 1), color_fp)
|
209 |
+
|
210 |
+
# FN
|
211 |
+
fn_map = imread(
|
212 |
+
model_path / "eval-metrics/fn" / "{}_fn.png".format(Path(img_filename).stem)
|
213 |
+
)
|
214 |
+
fn_map = np.tile(np.expand_dims(fn_map, axis=2), reps=(1, 1, 3))
|
215 |
+
|
216 |
+
fn_map_colmap = fn_map.astype(float)
|
217 |
+
fn_map_colmap = map_color(fn_map_colmap, (1, 1, 1), color_fn)
|
218 |
+
|
219 |
+
# TP
|
220 |
+
tp_map = imread(
|
221 |
+
model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
|
222 |
+
)
|
223 |
+
tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
|
224 |
+
|
225 |
+
tp_map_colmap = tp_map.astype(float)
|
226 |
+
tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
|
227 |
+
|
228 |
+
# TN
|
229 |
+
tn_map = imread(
|
230 |
+
model_path / "eval-metrics/tn" / "{}_tn.png".format(Path(img_filename).stem)
|
231 |
+
)
|
232 |
+
tn_map = np.tile(np.expand_dims(tn_map, axis=2), reps=(1, 1, 3))
|
233 |
+
|
234 |
+
tn_map_colmap = tn_map.astype(float)
|
235 |
+
tn_map_colmap = map_color(tn_map_colmap, (1, 1, 1), color_tn)
|
236 |
+
|
237 |
+
label_colmap = label.astype(float)
|
238 |
+
label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
|
239 |
+
label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
|
240 |
+
label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma
|
241 |
+
|
242 |
+
# Combine masks
|
243 |
+
maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
|
244 |
+
maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
|
245 |
+
maps_ma = maps_ma.mask * img + maps_ma
|
246 |
+
|
247 |
+
ax.imshow(img)
|
248 |
+
ax.imshow(label_colmap_ma, alpha=0.5)
|
249 |
+
ax.imshow(maps_ma, alpha=0.5)
|
250 |
+
ax.axis("off")
|
251 |
+
|
252 |
+
if add_title:
|
253 |
+
ax.set_title("Metric", rotation=0, fontsize="x-large")
|
254 |
+
|
255 |
+
|
256 |
+
def plot_edge_coherence(ax, img, metric, label, pred, img_id, n_, add_title, do_legend):
|
257 |
+
pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
|
258 |
+
|
259 |
+
ec, pred_ec, label_ec = edges_coherence_std_min(
|
260 |
+
np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
|
261 |
+
)
|
262 |
+
|
263 |
+
##################
|
264 |
+
# Edge distances #
|
265 |
+
##################
|
266 |
+
|
267 |
+
# Location of edges
|
268 |
+
pred_ec_coord = np.argwhere(pred_ec > 0)
|
269 |
+
label_ec_coord = np.argwhere(label_ec > 0)
|
270 |
+
|
271 |
+
# Normalized pairwise distances between pred and label
|
272 |
+
dist_mat = np.divide(
|
273 |
+
euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
|
274 |
+
)
|
275 |
+
|
276 |
+
# Standard deviation of the minimum distance from pred to label
|
277 |
+
min_dist = np.min(dist_mat, axis=1) # noqa: F841
|
278 |
+
|
279 |
+
#############
|
280 |
+
# Make plot #
|
281 |
+
#############
|
282 |
+
|
283 |
+
pred_ec = np.tile(
|
284 |
+
np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
|
285 |
+
)
|
286 |
+
pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
|
287 |
+
pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred) # noqa: F841
|
288 |
+
|
289 |
+
label_ec = np.tile(
|
290 |
+
np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
|
291 |
+
)
|
292 |
+
label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
|
293 |
+
label_ec_colmap_ma = np.ma.masked_not_equal( # noqa: F841
|
294 |
+
label_ec_colmap, color_must
|
295 |
+
)
|
296 |
+
|
297 |
+
# Combined pred and label edges
|
298 |
+
combined_ec = pred_ec_colmap + label_ec_colmap
|
299 |
+
combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
|
300 |
+
combined_ec_img = combined_ec_ma.mask * img + combined_ec
|
301 |
+
|
302 |
+
# Pred
|
303 |
+
pred_colmap = pred.astype(float)
|
304 |
+
pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
|
305 |
+
pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
|
306 |
+
|
307 |
+
# Must
|
308 |
+
label_colmap = label.astype(float)
|
309 |
+
label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
|
310 |
+
label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)
|
311 |
+
|
312 |
+
# TP
|
313 |
+
tp_map = imread(
|
314 |
+
model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
|
315 |
+
)
|
316 |
+
tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
|
317 |
+
tp_map_colmap = tp_map.astype(float)
|
318 |
+
tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
|
319 |
+
tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)
|
320 |
+
|
321 |
+
# Combination
|
322 |
+
comb_pred = (
|
323 |
+
(pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
|
324 |
+
& tp_map_colmap_ma.mask
|
325 |
+
& combined_ec_ma.mask
|
326 |
+
) * pred_colmap
|
327 |
+
comb_label = (
|
328 |
+
(label_colmap_ma.mask ^ pred_colmap_ma.mask)
|
329 |
+
& pred_colmap_ma.mask
|
330 |
+
& combined_ec_ma.mask
|
331 |
+
) * label_colmap
|
332 |
+
comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
|
333 |
+
combined = comb_tp + comb_label + comb_pred
|
334 |
+
combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
|
335 |
+
combined_ma = combined_ma.mask * combined_ec_img + combined_ma
|
336 |
+
|
337 |
+
ax.imshow(combined_ec_img, alpha=1)
|
338 |
+
ax.imshow(combined_ma, alpha=0.5)
|
339 |
+
ax.axis("off")
|
340 |
+
|
341 |
+
# Plot lines
|
342 |
+
idx_sort_x = np.argsort(pred_ec_coord[:, 1])
|
343 |
+
offset = 100
|
344 |
+
for idx in range(offset, pred_ec_coord.shape[0], offset):
|
345 |
+
y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
|
346 |
+
argmin = np.argmin(dist_mat[idx_sort_x[idx]])
|
347 |
+
y1, x1 = label_ec_coord[argmin, :]
|
348 |
+
ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)
|
349 |
+
|
350 |
+
if add_title:
|
351 |
+
ax.set_title("Metric", rotation=0, fontsize="x-large")
|
352 |
+
|
353 |
+
|
354 |
+
def plot_images_metric(
|
355 |
+
axes, metric, img_filename, img_id, n_, srs_sel, add_title, do_legend
|
356 |
+
):
|
357 |
+
|
358 |
+
# Read images
|
359 |
+
img_path = imgs_orig_path / img_filename
|
360 |
+
label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
|
361 |
+
img, label = crop_and_resize(img_path, label_path)
|
362 |
+
img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
|
363 |
+
|
364 |
+
pred = imread(
|
365 |
+
model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
|
366 |
+
)
|
367 |
+
|
368 |
+
# Label
|
369 |
+
plot_labels(axes[0], img, label, img_id, n_, add_title, do_legend)
|
370 |
+
|
371 |
+
# Prediction
|
372 |
+
plot_pred(axes[1], img, pred, img_id, add_title, do_legend)
|
373 |
+
|
374 |
+
# Correct / incorrect
|
375 |
+
if metric in ["error", "f05"]:
|
376 |
+
plot_correct_incorrect(
|
377 |
+
axes[2],
|
378 |
+
img_filename,
|
379 |
+
img,
|
380 |
+
metric,
|
381 |
+
label,
|
382 |
+
img_id,
|
383 |
+
n_,
|
384 |
+
add_title,
|
385 |
+
do_legend=False,
|
386 |
+
)
|
387 |
+
handles = []
|
388 |
+
lw = 1.0
|
389 |
+
handles.append(
|
390 |
+
mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
|
391 |
+
)
|
392 |
+
handles.append(
|
393 |
+
mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66)
|
394 |
+
)
|
395 |
+
handles.append(
|
396 |
+
mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66)
|
397 |
+
)
|
398 |
+
handles.append(
|
399 |
+
mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66)
|
400 |
+
)
|
401 |
+
handles.append(
|
402 |
+
mpatches.Patch(
|
403 |
+
facecolor=color_may,
|
404 |
+
label="May-be-flooded",
|
405 |
+
linewidth=lw,
|
406 |
+
alpha=0.66,
|
407 |
+
)
|
408 |
+
)
|
409 |
+
labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
|
410 |
+
if metric == "error":
|
411 |
+
if n_ in [1, 3, 5]:
|
412 |
+
title = "Low error rate"
|
413 |
+
else:
|
414 |
+
title = "High error rate"
|
415 |
+
else:
|
416 |
+
if n_ in [1, 3, 5]:
|
417 |
+
title = "High F05 score"
|
418 |
+
else:
|
419 |
+
title = "Low F05 score"
|
420 |
+
# Edge coherence
|
421 |
+
elif metric == "edge_coherence":
|
422 |
+
plot_edge_coherence(
|
423 |
+
axes[2], img, metric, label, pred, img_id, n_, add_title, do_legend=False
|
424 |
+
)
|
425 |
+
handles = []
|
426 |
+
lw = 1.0
|
427 |
+
handles.append(
|
428 |
+
mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
|
429 |
+
)
|
430 |
+
handles.append(
|
431 |
+
mpatches.Patch(facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66)
|
432 |
+
)
|
433 |
+
handles.append(
|
434 |
+
mpatches.Patch(
|
435 |
+
facecolor=color_must,
|
436 |
+
label="Must-be-flooded",
|
437 |
+
linewidth=lw,
|
438 |
+
alpha=0.66,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
labels = ["TP", "Prediction", "Must-be-flooded"]
|
442 |
+
if n_ in [1, 3, 5]:
|
443 |
+
title = "High edge coherence"
|
444 |
+
else:
|
445 |
+
title = "Low edge coherence"
|
446 |
+
|
447 |
+
else:
|
448 |
+
raise ValueError
|
449 |
+
|
450 |
+
labels_values_title = "Error: {:.4f} \nFO5: {:.4f} \nEdge coherence: {:.4f}".format(
|
451 |
+
srs_sel.error, srs_sel.f05, srs_sel.edge_coherence
|
452 |
+
)
|
453 |
+
|
454 |
+
plot_legend(axes[3], img, handles, labels, labels_values_title, title)
|
455 |
+
|
456 |
+
|
457 |
+
def plot_legend(ax, img, handles, labels, labels_values_title, title):
|
458 |
+
img_ = np.zeros_like(img, dtype=np.uint8)
|
459 |
+
img_.fill(255)
|
460 |
+
ax.imshow(img_)
|
461 |
+
ax.axis("off")
|
462 |
+
|
463 |
+
leg1 = ax.legend(
|
464 |
+
handles=handles,
|
465 |
+
labels=labels,
|
466 |
+
title=title,
|
467 |
+
title_fontsize="medium",
|
468 |
+
labelspacing=0.6,
|
469 |
+
loc="upper left",
|
470 |
+
fontsize="x-small",
|
471 |
+
frameon=False,
|
472 |
+
)
|
473 |
+
leg1._legend_box.align = "left"
|
474 |
+
|
475 |
+
leg2 = ax.legend(
|
476 |
+
title=labels_values_title,
|
477 |
+
title_fontsize="small",
|
478 |
+
loc="lower left",
|
479 |
+
frameon=False,
|
480 |
+
)
|
481 |
+
leg2._legend_box.align = "left"
|
482 |
+
|
483 |
+
ax.add_artist(leg1)
|
484 |
+
|
485 |
+
|
486 |
+
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
|
487 |
+
|
488 |
+
sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
|
489 |
+
|
490 |
+
# Set X-label
|
491 |
+
ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
|
492 |
+
|
493 |
+
# Set Y-label
|
494 |
+
ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")
|
495 |
+
|
496 |
+
# Change spines
|
497 |
+
sns.despine(ax=ax, left=True, bottom=True)
|
498 |
+
|
499 |
+
annotate_scatterplot(ax, dict_images, x_metric, y_metric)
|
500 |
+
|
501 |
+
|
502 |
+
def scatterplot_metrics(ax, df, df_all, dict_images, plot_all=False):
|
503 |
+
|
504 |
+
# Other
|
505 |
+
if plot_all:
|
506 |
+
sns.scatterplot(
|
507 |
+
data=df_all.loc[df_all.ground == True],
|
508 |
+
x="error", y="f05", hue="edge_coherence", ax=ax,
|
509 |
+
marker='+', alpha=0.25)
|
510 |
+
sns.scatterplot(
|
511 |
+
data=df_all.loc[df_all.instagan == True],
|
512 |
+
x="error", y="f05", hue="edge_coherence", ax=ax,
|
513 |
+
marker='x', alpha=0.25)
|
514 |
+
sns.scatterplot(
|
515 |
+
data=df_all.loc[(df_all.instagan == False) & (df_all.instagan == False) &
|
516 |
+
(df_all.model_feats != args.best_model)],
|
517 |
+
x="error", y="f05", hue="edge_coherence", ax=ax,
|
518 |
+
marker='s', alpha=0.25)
|
519 |
+
|
520 |
+
# Best model
|
521 |
+
cmap_ = sns.cubehelix_palette(as_cmap=True)
|
522 |
+
sns.scatterplot(
|
523 |
+
data=df, x="error", y="f05", hue="edge_coherence", ax=ax, palette=cmap_
|
524 |
+
)
|
525 |
+
|
526 |
+
norm = plt.Normalize(df["edge_coherence"].min(), df["edge_coherence"].max())
|
527 |
+
sm = plt.cm.ScalarMappable(cmap=cmap_, norm=norm)
|
528 |
+
sm.set_array([])
|
529 |
+
|
530 |
+
# Remove the legend and add a colorbar
|
531 |
+
ax.get_legend().remove()
|
532 |
+
ax_cbar = ax.figure.colorbar(sm)
|
533 |
+
ax_cbar.set_label("Edge coherence", labelpad=8)
|
534 |
+
|
535 |
+
# Set X-label
|
536 |
+
ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
|
537 |
+
|
538 |
+
# Set Y-label
|
539 |
+
ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")
|
540 |
+
|
541 |
+
annotate_scatterplot(ax, dict_images, "error", "f05")
|
542 |
+
|
543 |
+
# Change spines
|
544 |
+
sns.despine(ax=ax, left=True, bottom=True)
|
545 |
+
|
546 |
+
# Set XY limits
|
547 |
+
xlim = ax.get_xlim()
|
548 |
+
ylim = ax.get_ylim()
|
549 |
+
ax.set_xlim([0.0, xlim[1]])
|
550 |
+
ax.set_ylim([ylim[0], 1.0])
|
551 |
+
|
552 |
+
|
553 |
+
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
|
554 |
+
xlim = ax.get_xlim()
|
555 |
+
ylim = ax.get_ylim()
|
556 |
+
x_len = xlim[1] - xlim[0]
|
557 |
+
y_len = ylim[1] - ylim[0]
|
558 |
+
x_th = xlim[1] - x_len / 2.0
|
559 |
+
y_th = ylim[1] - y_len / 2.0
|
560 |
+
for text, d in dict_images.items():
|
561 |
+
if text in ["B", "D", "F"]:
|
562 |
+
x = d[x_metric]
|
563 |
+
y = d[y_metric]
|
564 |
+
|
565 |
+
x_text = x + x_len * offset if x < x_th else x - x_len * offset
|
566 |
+
y_text = y + y_len * offset if y < y_th else y - y_len * offset
|
567 |
+
|
568 |
+
ax.annotate(
|
569 |
+
xy=(x, y),
|
570 |
+
xycoords="data",
|
571 |
+
xytext=(x_text, y_text),
|
572 |
+
textcoords="data",
|
573 |
+
text=text,
|
574 |
+
arrowprops=dict(facecolor="black", shrink=0.05),
|
575 |
+
fontsize="medium",
|
576 |
+
color="black",
|
577 |
+
)
|
578 |
+
elif text == "A":
|
579 |
+
x = (
|
580 |
+
dict_images["A"][x_metric]
|
581 |
+
+ dict_images["C"][x_metric]
|
582 |
+
+ dict_images["E"][x_metric]
|
583 |
+
) / 3
|
584 |
+
y = (
|
585 |
+
dict_images["A"][y_metric]
|
586 |
+
+ dict_images["C"][y_metric]
|
587 |
+
+ dict_images["E"][y_metric]
|
588 |
+
) / 3
|
589 |
+
|
590 |
+
x_text = x + x_len * 2 * offset if x < x_th else x - x_len * 2 * offset
|
591 |
+
y_text = (
|
592 |
+
y + y_len * 0.45 * offset if y < y_th else y - y_len * 0.45 * offset
|
593 |
+
)
|
594 |
+
|
595 |
+
ax.annotate(
|
596 |
+
xy=(x, y),
|
597 |
+
xycoords="data",
|
598 |
+
xytext=(x_text, y_text),
|
599 |
+
textcoords="data",
|
600 |
+
text="A, C, E",
|
601 |
+
arrowprops=dict(facecolor="black", shrink=0.05),
|
602 |
+
fontsize="medium",
|
603 |
+
color="black",
|
604 |
+
)
|
605 |
+
|
606 |
+
|
607 |
+
if __name__ == "__main__":
|
608 |
+
# -----------------------------
|
609 |
+
# ----- Parse arguments -----
|
610 |
+
# -----------------------------
|
611 |
+
args = parsed_args()
|
612 |
+
print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
|
613 |
+
|
614 |
+
# Determine output dir
|
615 |
+
if args.output_dir is None:
|
616 |
+
output_dir = Path(os.environ["SLURM_TMPDIR"])
|
617 |
+
else:
|
618 |
+
output_dir = Path(args.output_dir)
|
619 |
+
if not output_dir.exists():
|
620 |
+
output_dir.mkdir(parents=True, exist_ok=False)
|
621 |
+
|
622 |
+
# Store args
|
623 |
+
output_yml = output_dir / "labels.yml"
|
624 |
+
with open(output_yml, "w") as f:
|
625 |
+
yaml.dump(vars(args), f)
|
626 |
+
|
627 |
+
# Data dirs
|
628 |
+
imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
|
629 |
+
labels_path = Path(args.masker_test_set_dir) / "labels"
|
630 |
+
|
631 |
+
# Read CSV
|
632 |
+
df_all = pd.read_csv(args.input_csv, index_col="model_img_idx")
|
633 |
+
|
634 |
+
# Select best model
|
635 |
+
df = df_all.loc[df_all.model_feats == args.best_model]
|
636 |
+
v_key, model_dir = df.model.unique()[0].split("/")
|
637 |
+
model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
|
638 |
+
|
639 |
+
# Set up plot
|
640 |
+
sns.reset_orig()
|
641 |
+
sns.set(style="whitegrid")
|
642 |
+
plt.rcParams.update({"font.family": "serif"})
|
643 |
+
plt.rcParams.update(
|
644 |
+
{
|
645 |
+
"font.serif": [
|
646 |
+
"Computer Modern Roman",
|
647 |
+
"Times New Roman",
|
648 |
+
"Utopia",
|
649 |
+
"New Century Schoolbook",
|
650 |
+
"Century Schoolbook L",
|
651 |
+
"ITC Bookman",
|
652 |
+
"Bookman",
|
653 |
+
"Times",
|
654 |
+
"Palatino",
|
655 |
+
"Charter",
|
656 |
+
"serif" "Bitstream Vera Serif",
|
657 |
+
"DejaVu Serif",
|
658 |
+
]
|
659 |
+
}
|
660 |
+
)
|
661 |
+
|
662 |
+
if args.seed:
|
663 |
+
np.random.seed(args.seed)
|
664 |
+
img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
665 |
+
dict_images = {}
|
666 |
+
idx = 0
|
667 |
+
|
668 |
+
# Define grid of subplots
|
669 |
+
grid_vmargin = 0.03 # Extent of the vertical margin between metric grids
|
670 |
+
ax_hspace = 0.04 # Extent of the vertical space between axes of same grid
|
671 |
+
ax_wspace = 0.05 # Extent of the horizontal space between axes of same grid
|
672 |
+
n_grids = len(metrics)
|
673 |
+
n_cols = 4
|
674 |
+
n_rows = 2
|
675 |
+
h_grid = (1.0 / n_grids) - ((n_grids - 1) * grid_vmargin) / n_grids
|
676 |
+
|
677 |
+
fig1 = plt.figure(dpi=200, figsize=(11, 13))
|
678 |
+
|
679 |
+
n_ = 0
|
680 |
+
add_title = False
|
681 |
+
for metric_id, metric in enumerate(metrics):
|
682 |
+
|
683 |
+
# Create grid
|
684 |
+
top_grid = 1.0 - metric_id * h_grid - metric_id * grid_vmargin
|
685 |
+
bottom_grid = top_grid - h_grid
|
686 |
+
gridspec = GridSpec(
|
687 |
+
n_rows,
|
688 |
+
n_cols,
|
689 |
+
wspace=ax_wspace,
|
690 |
+
hspace=ax_hspace,
|
691 |
+
bottom=bottom_grid,
|
692 |
+
top=top_grid,
|
693 |
+
)
|
694 |
+
|
695 |
+
# Select best
|
696 |
+
if metric == "error":
|
697 |
+
ascending = True
|
698 |
+
else:
|
699 |
+
ascending = False
|
700 |
+
idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
|
701 |
+
srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
|
702 |
+
img_id = img_ids[idx]
|
703 |
+
dict_images.update({img_id: srs_sel})
|
704 |
+
# Read images
|
705 |
+
img_filename = srs_sel.filename
|
706 |
+
|
707 |
+
axes_row = [fig1.add_subplot(gridspec[0, c]) for c in range(n_cols)]
|
708 |
+
if not args.no_images:
|
709 |
+
n_ += 1
|
710 |
+
if metric_id == 0:
|
711 |
+
add_title = True
|
712 |
+
plot_images_metric(
|
713 |
+
axes_row,
|
714 |
+
metric,
|
715 |
+
img_filename,
|
716 |
+
img_id,
|
717 |
+
n_,
|
718 |
+
srs_sel,
|
719 |
+
add_title=add_title,
|
720 |
+
do_legend=False,
|
721 |
+
)
|
722 |
+
add_title = False
|
723 |
+
|
724 |
+
idx += 1
|
725 |
+
print("1 more row done.")
|
726 |
+
# Select worst
|
727 |
+
if metric == "error":
|
728 |
+
ascending = False
|
729 |
+
else:
|
730 |
+
ascending = True
|
731 |
+
idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
|
732 |
+
srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
|
733 |
+
img_id = img_ids[idx]
|
734 |
+
dict_images.update({img_id: srs_sel})
|
735 |
+
# Read images
|
736 |
+
img_filename = srs_sel.filename
|
737 |
+
|
738 |
+
axes_row = [fig1.add_subplot(gridspec[1, c]) for c in range(n_cols)]
|
739 |
+
if not args.no_images:
|
740 |
+
n_ += 1
|
741 |
+
plot_images_metric(
|
742 |
+
axes_row,
|
743 |
+
metric,
|
744 |
+
img_filename,
|
745 |
+
img_id,
|
746 |
+
n_,
|
747 |
+
srs_sel,
|
748 |
+
add_title=add_title,
|
749 |
+
do_legend=False,
|
750 |
+
)
|
751 |
+
|
752 |
+
idx += 1
|
753 |
+
print("1 more row done.")
|
754 |
+
|
755 |
+
output_fig = output_dir / "all_metrics.png"
|
756 |
+
|
757 |
+
fig1.tight_layout() # (pad=1.5) #
|
758 |
+
fig1.savefig(output_fig, dpi=fig1.dpi, bbox_inches="tight")
|
759 |
+
|
760 |
+
# Scatter plot
|
761 |
+
fig2 = plt.figure(dpi=200)
|
762 |
+
|
763 |
+
scatterplot_metrics(fig2.gca(), df, df_all, dict_images)
|
764 |
+
|
765 |
+
# fig2, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
|
766 |
+
#
|
767 |
+
# scatterplot_metrics_pair(axes[0], df, "error", "f05", dict_images)
|
768 |
+
# scatterplot_metrics_pair(axes[1], df, "error", "edge_coherence", dict_images)
|
769 |
+
# scatterplot_metrics_pair(axes[2], df, "f05", "edge_coherence", dict_images)
|
770 |
+
|
771 |
+
output_fig = output_dir / "scatterplots.png"
|
772 |
+
fig2.savefig(output_fig, dpi=fig2.dpi, bbox_inches="tight")
|
requirements-3.8.2.txt
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
addict==2.4.0
|
2 |
+
APScheduler==3.7.0
|
3 |
+
attrs==21.2.0
|
4 |
+
backcall==0.2.0
|
5 |
+
Brotli==1.0.9
|
6 |
+
certifi==2021.5.30
|
7 |
+
charset-normalizer==2.0.4
|
8 |
+
click==8.0.1
|
9 |
+
codecarbon==1.2.0
|
10 |
+
comet-ml==3.15.3
|
11 |
+
configobj==5.0.6
|
12 |
+
cycler==0.10.0
|
13 |
+
dash==2.0.0
|
14 |
+
dash-bootstrap-components==0.13.0
|
15 |
+
dash-core-components==2.0.0
|
16 |
+
dash-html-components==2.0.0
|
17 |
+
dash-table==5.0.0
|
18 |
+
dataclasses==0.6
|
19 |
+
decorator==5.0.9
|
20 |
+
dulwich==0.20.25
|
21 |
+
everett==2.0.1
|
22 |
+
filelock==3.0.12
|
23 |
+
fire==0.4.0
|
24 |
+
Flask==2.0.1
|
25 |
+
Flask-Compress==1.10.1
|
26 |
+
future==0.18.2
|
27 |
+
gdown==3.13.0
|
28 |
+
hydra-core==0.11.3
|
29 |
+
idna==3.2
|
30 |
+
imageio==2.9.0
|
31 |
+
ipython==7.27.0
|
32 |
+
itsdangerous==2.0.1
|
33 |
+
jedi==0.18.0
|
34 |
+
Jinja2==3.0.1
|
35 |
+
joblib==1.0.1
|
36 |
+
jsonschema==3.2.0
|
37 |
+
kiwisolver==1.3.2
|
38 |
+
kornia==0.5.10
|
39 |
+
MarkupSafe==2.0.1
|
40 |
+
matplotlib==3.4.3
|
41 |
+
matplotlib-inline==0.1.2
|
42 |
+
networkx==2.6.2
|
43 |
+
numpy==1.21.2
|
44 |
+
nvidia-ml-py3==7.352.0
|
45 |
+
omegaconf==1.4.1
|
46 |
+
opencv-python==4.5.3.56
|
47 |
+
packaging==21.0
|
48 |
+
pandas==1.3.2
|
49 |
+
parso==0.8.2
|
50 |
+
pexpect==4.8.0
|
51 |
+
pickleshare==0.7.5
|
52 |
+
Pillow==8.3.2
|
53 |
+
plotly==5.3.1
|
54 |
+
prompt-toolkit==3.0.20
|
55 |
+
ptyprocess==0.7.0
|
56 |
+
py-cpuinfo==8.0.0
|
57 |
+
Pygments==2.10.0
|
58 |
+
pynvml==11.0.0
|
59 |
+
pyparsing==2.4.7
|
60 |
+
pyrsistent==0.18.0
|
61 |
+
PySocks==1.7.1
|
62 |
+
python-dateutil==2.8.2
|
63 |
+
pytorch-ranger==0.1.1
|
64 |
+
pytz==2021.1
|
65 |
+
PyWavelets==1.1.1
|
66 |
+
PyYAML==5.4.1
|
67 |
+
requests==2.26.0
|
68 |
+
requests-toolbelt==0.9.1
|
69 |
+
scikit-image==0.18.3
|
70 |
+
scikit-learn==0.24.2
|
71 |
+
scipy==1.7.1
|
72 |
+
seaborn==0.11.2
|
73 |
+
semantic-version==2.8.5
|
74 |
+
six==1.16.0
|
75 |
+
tenacity==8.0.1
|
76 |
+
termcolor==1.1.0
|
77 |
+
threadpoolctl==2.2.0
|
78 |
+
tifffile==2021.8.30
|
79 |
+
torch==1.7.1
|
80 |
+
torch-optimizer==0.1.0
|
81 |
+
torchvision==0.8.2
|
82 |
+
tqdm==4.62.2
|
83 |
+
traitlets==5.1.0
|
84 |
+
typing-extensions==3.10.0.2
|
85 |
+
tzlocal==2.1
|
86 |
+
urllib3==1.26.6
|
87 |
+
wcwidth==0.2.5
|
88 |
+
websocket-client==1.2.1
|
89 |
+
Werkzeug==2.0.1
|
90 |
+
wrapt==1.12.1
|
91 |
+
wurlitzer==3.0.2
|
requirements-any.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
addict
|
2 |
+
codecarbon
|
3 |
+
comet_ml
|
4 |
+
hydra-core==0.11.3
|
5 |
+
kornia
|
6 |
+
omegaconf==1.4.1
|
7 |
+
matplotlib
|
8 |
+
numpy
|
9 |
+
opencv-python
|
10 |
+
packaging
|
11 |
+
pandas
|
12 |
+
PyYAML
|
13 |
+
scikit-image
|
14 |
+
scikit-learn
|
15 |
+
scipy
|
16 |
+
seaborn
|
17 |
+
torch==1.7.0
|
18 |
+
torch-optimizer
|
19 |
+
torchvision==0.8.1
|
20 |
+
tqdm
|
sbatch.py
ADDED
@@ -0,0 +1,933 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import itertools
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
from collections import defaultdict
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import yaml
|
12 |
+
|
13 |
+
|
14 |
+
def flatten_conf(conf, to={}, parents=[]):
|
15 |
+
"""
|
16 |
+
Flattens a configuration dict: nested dictionaries are flattened
|
17 |
+
as key1.key2.key3 = value
|
18 |
+
|
19 |
+
conf.yaml:
|
20 |
+
```yaml
|
21 |
+
a: 1
|
22 |
+
b:
|
23 |
+
c: 2
|
24 |
+
d:
|
25 |
+
e: 3
|
26 |
+
g:
|
27 |
+
sample: sequential
|
28 |
+
from: [4, 5]
|
29 |
+
```
|
30 |
+
|
31 |
+
Is flattened to
|
32 |
+
|
33 |
+
{
|
34 |
+
"a": 1,
|
35 |
+
"b.c": 2,
|
36 |
+
"b.d.e": 3,
|
37 |
+
"b.g": {
|
38 |
+
"sample": "sequential",
|
39 |
+
"from": [4, 5]
|
40 |
+
}
|
41 |
+
}
|
42 |
+
|
43 |
+
Does not affect sampling dicts.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
conf (dict): the configuration to flatten
|
47 |
+
new (dict, optional): the target flatenned dict. Defaults to {}.
|
48 |
+
parents (list, optional): a final value's list of parents. Defaults to [].
|
49 |
+
"""
|
50 |
+
for k, v in conf.items():
|
51 |
+
if isinstance(v, dict) and "sample" not in v:
|
52 |
+
flatten_conf(v, to, parents + [k])
|
53 |
+
else:
|
54 |
+
new_k = ".".join([str(p) for p in parents + [k]])
|
55 |
+
to[new_k] = v
|
56 |
+
|
57 |
+
|
58 |
+
def env_to_path(path):
|
59 |
+
"""Transorms an environment variable mention in a json
|
60 |
+
into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds
|
61 |
+
|
62 |
+
Args:
|
63 |
+
path (str): path potentially containing the env variable
|
64 |
+
|
65 |
+
"""
|
66 |
+
path_elements = path.split("/")
|
67 |
+
new_path = []
|
68 |
+
for el in path_elements:
|
69 |
+
if "$" in el:
|
70 |
+
new_path.append(os.environ[el.replace("$", "")])
|
71 |
+
else:
|
72 |
+
new_path.append(el)
|
73 |
+
return "/".join(new_path)
|
74 |
+
|
75 |
+
|
76 |
+
class C:
|
77 |
+
HEADER = "\033[95m"
|
78 |
+
OKBLUE = "\033[94m"
|
79 |
+
OKGREEN = "\033[92m"
|
80 |
+
WARNING = "\033[93m"
|
81 |
+
FAIL = "\033[91m"
|
82 |
+
ENDC = "\033[0m"
|
83 |
+
BOLD = "\033[1m"
|
84 |
+
UNDERLINE = "\033[4m"
|
85 |
+
ITALIC = "\33[3m"
|
86 |
+
BEIGE = "\33[36m"
|
87 |
+
|
88 |
+
|
89 |
+
def escape_path(path):
|
90 |
+
p = str(path)
|
91 |
+
return p.replace(" ", "\ ").replace("(", "\(").replace(")", "\)") # noqa: W605
|
92 |
+
|
93 |
+
|
94 |
+
def warn(*args, **kwargs):
|
95 |
+
print("{}{}{}".format(C.WARNING, " ".join(args), C.ENDC), **kwargs)
|
96 |
+
|
97 |
+
|
98 |
+
def parse_jobID(command_output):
|
99 |
+
"""
|
100 |
+
get job id from successful sbatch command output like
|
101 |
+
`Submitted batch job 599583`
|
102 |
+
|
103 |
+
Args:
|
104 |
+
command_output (str): sbatch command's output
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
int: the slurm job's ID
|
108 |
+
"""
|
109 |
+
command_output = command_output.strip()
|
110 |
+
if isinstance(command_output, str):
|
111 |
+
if "Submitted batch job" in command_output:
|
112 |
+
return int(command_output.split()[-1])
|
113 |
+
|
114 |
+
return -1
|
115 |
+
|
116 |
+
|
117 |
+
def now():
|
118 |
+
return str(datetime.datetime.now()).replace(" ", "_")
|
119 |
+
|
120 |
+
|
121 |
+
def cols():
|
122 |
+
try:
|
123 |
+
col = os.get_terminal_size().columns
|
124 |
+
except Exception:
|
125 |
+
col = 50
|
126 |
+
return col
|
127 |
+
|
128 |
+
|
129 |
+
def print_box(txt):
|
130 |
+
if not txt:
|
131 |
+
txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
|
132 |
+
lt = 7
|
133 |
+
else:
|
134 |
+
lt = len(txt)
|
135 |
+
nlt = lt + 12
|
136 |
+
txt = "|" + " " * 5 + txt + " " * 5 + "|"
|
137 |
+
line = "-" * nlt
|
138 |
+
empty = "|" + " " * (nlt - 2) + "|"
|
139 |
+
print(line)
|
140 |
+
print(empty)
|
141 |
+
print(txt)
|
142 |
+
print(empty)
|
143 |
+
print(line)
|
144 |
+
|
145 |
+
|
146 |
+
def print_header(idx):
|
147 |
+
b = C.BOLD
|
148 |
+
bl = C.OKBLUE
|
149 |
+
e = C.ENDC
|
150 |
+
char = "≡"
|
151 |
+
c = cols()
|
152 |
+
|
153 |
+
txt = " " * 20
|
154 |
+
txt += f"{b}{bl}Run {idx}{e}"
|
155 |
+
txt += " " * 20
|
156 |
+
ln = len(txt) - len(b) - len(bl) - len(e)
|
157 |
+
t = int(np.floor((c - ln) / 2))
|
158 |
+
tt = int(np.ceil((c - ln) / 2))
|
159 |
+
|
160 |
+
print(char * c)
|
161 |
+
print(char * t + " " * ln + char * tt)
|
162 |
+
print(char * t + txt + char * tt)
|
163 |
+
print(char * t + " " * ln + char * tt)
|
164 |
+
print(char * c)
|
165 |
+
|
166 |
+
|
167 |
+
def print_footer():
|
168 |
+
c = cols()
|
169 |
+
char = "﹎"
|
170 |
+
print()
|
171 |
+
print(char * (c // len(char)))
|
172 |
+
print()
|
173 |
+
print(" " * (c // 2) + "•" + " " * (c - c // 2 - 1))
|
174 |
+
print()
|
175 |
+
|
176 |
+
|
177 |
+
def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=[]):
|
178 |
+
exclude = set(exclude)
|
179 |
+
if summary is None:
|
180 |
+
summary = defaultdict(list)
|
181 |
+
for k, v in tmp_template_dict.items():
|
182 |
+
if k not in exclude:
|
183 |
+
summary[k].append(v)
|
184 |
+
for k, v in tmp_train_args_dict.items():
|
185 |
+
if k not in exclude:
|
186 |
+
if isinstance(v, list):
|
187 |
+
v = str(v)
|
188 |
+
summary[k].append(v)
|
189 |
+
return summary
|
190 |
+
|
191 |
+
|
192 |
+
def search_summary_table(summary, summary_dir=None):
|
193 |
+
# filter out constant values
|
194 |
+
summary = {k: v for k, v in summary.items() if len(set(v)) > 1}
|
195 |
+
|
196 |
+
# if everything is constant: no summary
|
197 |
+
if not summary:
|
198 |
+
return None, None
|
199 |
+
|
200 |
+
# find number of searches
|
201 |
+
n_searches = len(list(summary.values())[0])
|
202 |
+
|
203 |
+
# print section title
|
204 |
+
print(
|
205 |
+
"{}{}{}Varying values across {} experiments:{}\n".format(
|
206 |
+
C.OKBLUE,
|
207 |
+
C.BOLD,
|
208 |
+
C.UNDERLINE,
|
209 |
+
n_searches,
|
210 |
+
C.ENDC,
|
211 |
+
)
|
212 |
+
)
|
213 |
+
|
214 |
+
# first column holds the Exp. number
|
215 |
+
first_col = {
|
216 |
+
"len": 8, # length of a column, to split columns according to terminal width
|
217 |
+
"str": ["| Exp. |", "|:----:|"]
|
218 |
+
+ [
|
219 |
+
"| {0:^{1}} |".format(i, 4) for i in range(n_searches)
|
220 |
+
], # list of values to print
|
221 |
+
}
|
222 |
+
|
223 |
+
print_columns = [[first_col]]
|
224 |
+
file_columns = [first_col]
|
225 |
+
for k in sorted(summary.keys()):
|
226 |
+
v = summary[k]
|
227 |
+
col_title = f" {k} |"
|
228 |
+
col_blank_line = f":{'-' * len(k)}-|"
|
229 |
+
col_values = [
|
230 |
+
" {0:{1}} |".format(
|
231 |
+
crop_string(
|
232 |
+
str(crop_float(v[idx], min([5, len(k) - 2]))), len(k)
|
233 |
+
), # crop floats and long strings
|
234 |
+
len(k),
|
235 |
+
)
|
236 |
+
for idx in range(len(v))
|
237 |
+
]
|
238 |
+
|
239 |
+
# create column object
|
240 |
+
col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}
|
241 |
+
|
242 |
+
# if adding a new column would overflow the terminal and mess up printing, start
|
243 |
+
# new set of columns
|
244 |
+
if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
|
245 |
+
print_columns.append([first_col])
|
246 |
+
|
247 |
+
# store current column to latest group of columns
|
248 |
+
print_columns[-1].append(col)
|
249 |
+
file_columns.append(col)
|
250 |
+
|
251 |
+
print_table = ""
|
252 |
+
# print each column group individually
|
253 |
+
for colgroup in print_columns:
|
254 |
+
# print columns line by line
|
255 |
+
for i in range(n_searches + 2):
|
256 |
+
# get value of column for current line i
|
257 |
+
for col in colgroup:
|
258 |
+
print_table += col["str"][i]
|
259 |
+
# next line for current columns
|
260 |
+
print_table += "\n"
|
261 |
+
|
262 |
+
# new lines for new column group
|
263 |
+
print_table += "\n"
|
264 |
+
|
265 |
+
file_table = ""
|
266 |
+
for i in range(n_searches + 2):
|
267 |
+
# get value of column for current line i
|
268 |
+
for col in file_columns:
|
269 |
+
file_table += col["str"][i]
|
270 |
+
# next line for current columns
|
271 |
+
file_table += "\n"
|
272 |
+
|
273 |
+
summary_path = None
|
274 |
+
if summary_dir is not None:
|
275 |
+
summary_path = summary_dir / (now() + ".md")
|
276 |
+
with summary_path.open("w") as f:
|
277 |
+
f.write(file_table.strip())
|
278 |
+
|
279 |
+
return print_table, summary_path
|
280 |
+
|
281 |
+
|
282 |
+
def clean_arg(v):
|
283 |
+
"""
|
284 |
+
chain cleaning function
|
285 |
+
|
286 |
+
Args:
|
287 |
+
v (any): arg to pass to train.py
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
str: parsed value to string
|
291 |
+
"""
|
292 |
+
return stringify_list(crop_float(quote_string(resolve_env(v))))
|
293 |
+
|
294 |
+
|
295 |
+
def resolve_env(v):
|
296 |
+
"""
|
297 |
+
resolve env variables in paths
|
298 |
+
|
299 |
+
Args:
|
300 |
+
v (any): arg to pass to train.py
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
str: try and resolve an env variable
|
304 |
+
"""
|
305 |
+
if isinstance(v, str):
|
306 |
+
try:
|
307 |
+
if "$" in v:
|
308 |
+
if "/" in v:
|
309 |
+
v = env_to_path(v)
|
310 |
+
else:
|
311 |
+
_v = os.environ.get(v)
|
312 |
+
if _v is not None:
|
313 |
+
v = _v
|
314 |
+
except Exception:
|
315 |
+
pass
|
316 |
+
return v
|
317 |
+
|
318 |
+
|
319 |
+
def stringify_list(v):
|
320 |
+
"""
|
321 |
+
Stringify list (with double quotes) so that it can be passed a an argument
|
322 |
+
to train.py's hydra command-line parsing
|
323 |
+
|
324 |
+
Args:
|
325 |
+
v (any): value to clean
|
326 |
+
|
327 |
+
Returns:
|
328 |
+
any: type of v, str if v was a list
|
329 |
+
"""
|
330 |
+
if isinstance(v, list):
|
331 |
+
return '"{}"'.format(str(v).replace('"', "'"))
|
332 |
+
if isinstance(v, str):
|
333 |
+
if v.startswith("[") and v.endswith("]"):
|
334 |
+
return f'"{v}"'
|
335 |
+
return v
|
336 |
+
|
337 |
+
|
338 |
+
def quote_string(v):
|
339 |
+
"""
|
340 |
+
Add double quotes around string if it contains a " " or an =
|
341 |
+
|
342 |
+
Args:
|
343 |
+
v (any): value to clean
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
any: type of v, quoted if v is a string with " " or =
|
347 |
+
"""
|
348 |
+
if isinstance(v, str):
|
349 |
+
if " " in v or "=" in v:
|
350 |
+
return f'"{v}"'
|
351 |
+
return v
|
352 |
+
|
353 |
+
|
354 |
+
def crop_float(v, k=5):
|
355 |
+
"""
|
356 |
+
If v is a float, crop precision to 5 digits and return v as a str
|
357 |
+
|
358 |
+
Args:
|
359 |
+
v (any): value to crop if float
|
360 |
+
|
361 |
+
Returns:
|
362 |
+
any: cropped float as str if v is a float, original v otherwise
|
363 |
+
"""
|
364 |
+
if isinstance(v, float):
|
365 |
+
return "{0:.{1}g}".format(v, k)
|
366 |
+
return v
|
367 |
+
|
368 |
+
|
369 |
+
def compute_n_search(conf):
|
370 |
+
"""
|
371 |
+
Compute the number of searchs to do if using -1 as n_search and using
|
372 |
+
cartesian or sequential search
|
373 |
+
|
374 |
+
Args:
|
375 |
+
conf (dict): experimental configuration
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
int: size of the cartesian product or length of longest sequential field
|
379 |
+
"""
|
380 |
+
samples = defaultdict(list)
|
381 |
+
for k, v in conf.items():
|
382 |
+
if not isinstance(v, dict) or "sample" not in v:
|
383 |
+
continue
|
384 |
+
samples[v["sample"]].append(v)
|
385 |
+
|
386 |
+
totals = []
|
387 |
+
|
388 |
+
if "cartesian" in samples:
|
389 |
+
total = 1
|
390 |
+
for s in samples["cartesian"]:
|
391 |
+
total *= len(s["from"])
|
392 |
+
totals.append(total)
|
393 |
+
if "sequential" in samples:
|
394 |
+
total = max(map(len, [s["from"] for s in samples["sequential"]]))
|
395 |
+
totals.append(total)
|
396 |
+
|
397 |
+
if totals:
|
398 |
+
return max(totals)
|
399 |
+
|
400 |
+
raise ValueError(
|
401 |
+
"Used n_search=-1 without any field being 'cartesian' or 'sequential'"
|
402 |
+
)
|
403 |
+
|
404 |
+
|
405 |
+
def crop_string(s, k=10):
|
406 |
+
if len(s) <= k:
|
407 |
+
return s
|
408 |
+
else:
|
409 |
+
return s[: k - 2] + ".."
|
410 |
+
|
411 |
+
|
412 |
+
def sample_param(sample_dict):
|
413 |
+
"""sample a value (hyperparameter) from the instruction in the
|
414 |
+
sample dict:
|
415 |
+
{
|
416 |
+
"sample": "range | list",
|
417 |
+
"from": [min, max, step] | [v0, v1, v2 etc.]
|
418 |
+
}
|
419 |
+
if range, as np.arange is used, "from" MUST be a list, but may contain
|
420 |
+
only 1 (=min) or 2 (min and max) values, not necessarily 3
|
421 |
+
|
422 |
+
Args:
|
423 |
+
sample_dict (dict): instructions to sample a value
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
scalar: sampled value
|
427 |
+
"""
|
428 |
+
if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
|
429 |
+
return sample_dict
|
430 |
+
|
431 |
+
if sample_dict["sample"] == "cartesian":
|
432 |
+
assert isinstance(
|
433 |
+
sample_dict["from"], list
|
434 |
+
), "{}'s `from` field MUST be a list, found {}".format(
|
435 |
+
sample_dict["sample"], sample_dict["from"]
|
436 |
+
)
|
437 |
+
return "__cartesian__"
|
438 |
+
|
439 |
+
if sample_dict["sample"] == "sequential":
|
440 |
+
assert isinstance(
|
441 |
+
sample_dict["from"], list
|
442 |
+
), "{}'s `from` field MUST be a list, found {}".format(
|
443 |
+
sample_dict["sample"], sample_dict["from"]
|
444 |
+
)
|
445 |
+
return "__sequential__"
|
446 |
+
|
447 |
+
if sample_dict["sample"] == "range":
|
448 |
+
return np.random.choice(np.arange(*sample_dict["from"]))
|
449 |
+
|
450 |
+
if sample_dict["sample"] == "list":
|
451 |
+
return np.random.choice(sample_dict["from"])
|
452 |
+
|
453 |
+
if sample_dict["sample"] == "uniform":
|
454 |
+
return np.random.uniform(*sample_dict["from"])
|
455 |
+
|
456 |
+
raise ValueError("Unknown sample type in dict " + str(sample_dict))
|
457 |
+
|
458 |
+
|
459 |
+
def sample_sequentials(sequential_keys, exp, idx):
|
460 |
+
"""
|
461 |
+
Samples sequentially from the "from" values specified in each key of the
|
462 |
+
experimental configuration which have sample == "sequential"
|
463 |
+
Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
|
464 |
+
over each keys
|
465 |
+
|
466 |
+
Args:
|
467 |
+
sequential_keys (list): keys to be sampled sequentially
|
468 |
+
exp (dict): experimental config
|
469 |
+
idx (int): index of the current sample
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
conf: sampled dict
|
473 |
+
"""
|
474 |
+
conf = {}
|
475 |
+
for k in sequential_keys:
|
476 |
+
v = exp[k]["from"]
|
477 |
+
conf[k] = v[idx % len(v)]
|
478 |
+
return conf
|
479 |
+
|
480 |
+
|
481 |
+
def sample_cartesians(cartesian_keys, exp, idx):
|
482 |
+
"""
|
483 |
+
Returns the `idx`th item in the cartesian product of all cartesian keys to
|
484 |
+
be sampled.
|
485 |
+
|
486 |
+
Args:
|
487 |
+
cartesian_keys (list): keys in the experimental configuration that are to
|
488 |
+
be used in the full cartesian product
|
489 |
+
exp (dict): experimental configuration
|
490 |
+
idx (int): index of the current sample
|
491 |
+
|
492 |
+
Returns:
|
493 |
+
dict: sampled point in the cartesian space (with keys = cartesian_keys)
|
494 |
+
"""
|
495 |
+
conf = {}
|
496 |
+
cartesian_values = [exp[key]["from"] for key in cartesian_keys]
|
497 |
+
product = list(itertools.product(*cartesian_values))
|
498 |
+
for k, v in zip(cartesian_keys, product[idx % len(product)]):
|
499 |
+
conf[k] = v
|
500 |
+
return conf
|
501 |
+
|
502 |
+
|
503 |
+
def resolve(hp_conf, nb):
|
504 |
+
"""
|
505 |
+
Samples parameters parametrized in `exp`: should be a dict with
|
506 |
+
values which fit `sample_params(dic)`'s API
|
507 |
+
|
508 |
+
Args:
|
509 |
+
exp (dict): experiment's parametrization
|
510 |
+
nb (int): number of experiments to sample
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
dict: sampled configuration
|
514 |
+
"""
|
515 |
+
if nb == -1:
|
516 |
+
nb = compute_n_search(hp_conf)
|
517 |
+
|
518 |
+
confs = []
|
519 |
+
for idx in range(nb):
|
520 |
+
conf = {}
|
521 |
+
cartesians = []
|
522 |
+
sequentials = []
|
523 |
+
for k, v in hp_conf.items():
|
524 |
+
candidate = sample_param(v)
|
525 |
+
if candidate == "__cartesian__":
|
526 |
+
cartesians.append(k)
|
527 |
+
elif candidate == "__sequential__":
|
528 |
+
sequentials.append(k)
|
529 |
+
else:
|
530 |
+
conf[k] = candidate
|
531 |
+
if sequentials:
|
532 |
+
conf.update(sample_sequentials(sequentials, hp_conf, idx))
|
533 |
+
if cartesians:
|
534 |
+
conf.update(sample_cartesians(cartesians, hp_conf, idx))
|
535 |
+
confs.append(conf)
|
536 |
+
return confs
|
537 |
+
|
538 |
+
|
539 |
+
def get_template_params(template):
|
540 |
+
"""
|
541 |
+
extract args in template str as {arg}
|
542 |
+
|
543 |
+
Args:
|
544 |
+
template (str): sbatch template string
|
545 |
+
|
546 |
+
Returns:
|
547 |
+
list(str): Args required to format the template string
|
548 |
+
"""
|
549 |
+
return map(
|
550 |
+
lambda s: s.replace("{", "").replace("}", ""),
|
551 |
+
re.findall("\{.*?\}", template), # noqa: W605
|
552 |
+
)
|
553 |
+
|
554 |
+
|
555 |
+
def read_exp_conf(name):
|
556 |
+
"""
|
557 |
+
Read hp search configuration from shared/experiment/
|
558 |
+
specified with or without the .yaml extension
|
559 |
+
|
560 |
+
Args:
|
561 |
+
name (str): name of the template to find in shared/experiment/
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
Tuple(Path, dict): file path and loaded dict
|
565 |
+
"""
|
566 |
+
if ".yaml" not in name:
|
567 |
+
name += ".yaml"
|
568 |
+
paths = []
|
569 |
+
dirs = ["shared", "config"]
|
570 |
+
for d in dirs:
|
571 |
+
path = Path(__file__).parent / d / "experiment" / name
|
572 |
+
if path.exists():
|
573 |
+
paths.append(path.resolve())
|
574 |
+
|
575 |
+
if len(paths) == 0:
|
576 |
+
failed = [Path(__file__).parent / d / "experiment" for d in dirs]
|
577 |
+
s = "Could not find search config {} in :\n".format(name)
|
578 |
+
for fd in failed:
|
579 |
+
s += str(fd) + "\nAvailable:\n"
|
580 |
+
for ym in fd.glob("*.yaml"):
|
581 |
+
s += " " + ym.name + "\n"
|
582 |
+
raise ValueError(s)
|
583 |
+
|
584 |
+
if len(paths) == 2:
|
585 |
+
print(
|
586 |
+
"Warning: found 2 relevant files for search config:\n{}".format(
|
587 |
+
"\n".join(paths)
|
588 |
+
)
|
589 |
+
)
|
590 |
+
print("Using {}".format(paths[-1]))
|
591 |
+
|
592 |
+
with paths[-1].open("r") as f:
|
593 |
+
conf = yaml.safe_load(f)
|
594 |
+
|
595 |
+
flat_conf = {}
|
596 |
+
flatten_conf(conf, to=flat_conf)
|
597 |
+
|
598 |
+
return (paths[-1], flat_conf)
|
599 |
+
|
600 |
+
|
601 |
+
def read_template(name):
|
602 |
+
"""
|
603 |
+
Read template from shared/template/ specified with or without the .sh extension
|
604 |
+
|
605 |
+
Args:
|
606 |
+
name (str): name of the template to find in shared/template/
|
607 |
+
|
608 |
+
Returns:
|
609 |
+
str: file's content as 1 string
|
610 |
+
"""
|
611 |
+
if ".sh" not in name:
|
612 |
+
name += ".sh"
|
613 |
+
paths = []
|
614 |
+
dirs = ["shared", "config"]
|
615 |
+
for d in dirs:
|
616 |
+
path = Path(__file__).parent / d / "template" / name
|
617 |
+
if path.exists():
|
618 |
+
paths.append(path)
|
619 |
+
|
620 |
+
if len(paths) == 0:
|
621 |
+
failed = [Path(__file__).parent / d / "template" for d in dirs]
|
622 |
+
s = "Could not find template {} in :\n".format(name)
|
623 |
+
for fd in failed:
|
624 |
+
s += str(fd) + "\nAvailable:\n"
|
625 |
+
for ym in fd.glob("*.sh"):
|
626 |
+
s += " " + ym.name + "\n"
|
627 |
+
raise ValueError(s)
|
628 |
+
|
629 |
+
if len(paths) == 2:
|
630 |
+
print("Warning: found 2 relevant template files:\n{}".format("\n".join(paths)))
|
631 |
+
print("Using {}".format(paths[-1]))
|
632 |
+
|
633 |
+
with paths[-1].open("r") as f:
|
634 |
+
return f.read()
|
635 |
+
|
636 |
+
|
637 |
+
def is_sampled(key, conf):
|
638 |
+
"""
|
639 |
+
Is a key sampled or constant? Returns true if conf is empty
|
640 |
+
|
641 |
+
Args:
|
642 |
+
key (str): key to check
|
643 |
+
conf (dict): hyper parameter search configuration dict
|
644 |
+
|
645 |
+
Returns:
|
646 |
+
bool: key is sampled?
|
647 |
+
"""
|
648 |
+
return not conf or (
|
649 |
+
key in conf and isinstance(conf[key], dict) and "sample" in conf[key]
|
650 |
+
)
|
651 |
+
|
652 |
+
|
653 |
+
if __name__ == "__main__":
|
654 |
+
|
655 |
+
"""
|
656 |
+
Notes:
|
657 |
+
* Must provide template name as template=name
|
658 |
+
* `name`.sh should be in shared/template/
|
659 |
+
"""
|
660 |
+
|
661 |
+
# -------------------------------
|
662 |
+
# ----- Default Variables -----
|
663 |
+
# -------------------------------
|
664 |
+
|
665 |
+
args = sys.argv[1:]
|
666 |
+
command_output = ""
|
667 |
+
user = os.environ.get("USER")
|
668 |
+
home = os.environ.get("HOME")
|
669 |
+
exp_conf = {}
|
670 |
+
dev = False
|
671 |
+
escape = False
|
672 |
+
verbose = False
|
673 |
+
template_name = None
|
674 |
+
hp_exp_name = None
|
675 |
+
hp_search_nb = None
|
676 |
+
exp_path = None
|
677 |
+
resume = None
|
678 |
+
force_sbatchs = False
|
679 |
+
sbatch_base = Path(home) / "climategan_sbatchs"
|
680 |
+
summary_dir = Path(home) / "climategan_exp_summaries"
|
681 |
+
|
682 |
+
hp_search_private = set(["n_search", "template", "search", "summary_dir"])
|
683 |
+
|
684 |
+
sbatch_path = "hash"
|
685 |
+
|
686 |
+
# --------------------------
|
687 |
+
# ----- Sanity Check -----
|
688 |
+
# --------------------------
|
689 |
+
|
690 |
+
for arg in args:
|
691 |
+
if "=" not in arg or " = " in arg:
|
692 |
+
raise ValueError(
|
693 |
+
"Args should be passed as `key=value`. Received `{}`".format(arg)
|
694 |
+
)
|
695 |
+
|
696 |
+
# --------------------------------
|
697 |
+
# ----- Parse Command Line -----
|
698 |
+
# --------------------------------
|
699 |
+
|
700 |
+
args_dict = {arg.split("=")[0]: arg.split("=")[1] for arg in args}
|
701 |
+
|
702 |
+
assert "template" in args_dict, "Please specify template=xxx"
|
703 |
+
template = read_template(args_dict["template"])
|
704 |
+
template_dict = {k: None for k in get_template_params(template)}
|
705 |
+
|
706 |
+
train_args = []
|
707 |
+
for k, v in args_dict.items():
|
708 |
+
|
709 |
+
if k == "verbose":
|
710 |
+
if v != "0":
|
711 |
+
verbose = True
|
712 |
+
|
713 |
+
elif k == "sbatch_path":
|
714 |
+
sbatch_path = v
|
715 |
+
|
716 |
+
elif k == "sbatch_base":
|
717 |
+
sbatch_base = Path(v).resolve()
|
718 |
+
|
719 |
+
elif k == "force_sbatchs":
|
720 |
+
force_sbatchs = v.lower() == "true"
|
721 |
+
|
722 |
+
elif k == "dev":
|
723 |
+
if v.lower() != "false":
|
724 |
+
dev = True
|
725 |
+
|
726 |
+
elif k == "escape":
|
727 |
+
if v.lower() != "false":
|
728 |
+
escape = True
|
729 |
+
|
730 |
+
elif k == "template":
|
731 |
+
template_name = v
|
732 |
+
|
733 |
+
elif k == "exp":
|
734 |
+
hp_exp_name = v
|
735 |
+
|
736 |
+
elif k == "n_search":
|
737 |
+
hp_search_nb = int(v)
|
738 |
+
|
739 |
+
elif k == "resume":
|
740 |
+
resume = f'"{v}"'
|
741 |
+
template_dict[k] = f'"{v}"'
|
742 |
+
|
743 |
+
elif k == "summary_dir":
|
744 |
+
if v.lower() == "none":
|
745 |
+
summary_dir = None
|
746 |
+
else:
|
747 |
+
summary_dir = Path(v)
|
748 |
+
|
749 |
+
elif k in template_dict:
|
750 |
+
template_dict[k] = v
|
751 |
+
|
752 |
+
else:
|
753 |
+
train_args.append(f"{k}={v}")
|
754 |
+
|
755 |
+
# ------------------------------------
|
756 |
+
# ----- Load Experiment Config -----
|
757 |
+
# ------------------------------------
|
758 |
+
|
759 |
+
if hp_exp_name is not None:
|
760 |
+
exp_path, exp_conf = read_exp_conf(hp_exp_name)
|
761 |
+
if "n_search" in exp_conf and hp_search_nb is None:
|
762 |
+
hp_search_nb = exp_conf["n_search"]
|
763 |
+
|
764 |
+
assert (
|
765 |
+
hp_search_nb is not None
|
766 |
+
), "n_search should be specified in a yaml file or from the command line"
|
767 |
+
|
768 |
+
hps = resolve(exp_conf, hp_search_nb)
|
769 |
+
|
770 |
+
else:
|
771 |
+
hps = [None]
|
772 |
+
|
773 |
+
# ---------------------------------
|
774 |
+
# ----- Run All Experiments -----
|
775 |
+
# ---------------------------------
|
776 |
+
if summary_dir is not None:
|
777 |
+
summary_dir.mkdir(exist_ok=True, parents=True)
|
778 |
+
summary = None
|
779 |
+
|
780 |
+
for hp_idx, hp in enumerate(hps):
|
781 |
+
|
782 |
+
# copy shared values
|
783 |
+
tmp_template_dict = template_dict.copy()
|
784 |
+
tmp_train_args = train_args.copy()
|
785 |
+
tmp_train_args_dict = {
|
786 |
+
arg.split("=")[0]: arg.split("=")[1] for arg in tmp_train_args
|
787 |
+
}
|
788 |
+
print_header(hp_idx)
|
789 |
+
# override shared values with run-specific values for run hp_idx/n_search
|
790 |
+
if hp is not None:
|
791 |
+
for k, v in hp.items():
|
792 |
+
if k == "resume" and resume is None:
|
793 |
+
resume = f'"{v}"'
|
794 |
+
# hp-search params to ignore
|
795 |
+
if k in hp_search_private:
|
796 |
+
continue
|
797 |
+
|
798 |
+
if k == "codeloc":
|
799 |
+
v = escape_path(v)
|
800 |
+
|
801 |
+
if k == "output":
|
802 |
+
Path(v).parent.mkdir(parents=True, exist_ok=True)
|
803 |
+
|
804 |
+
# override template params depending on exp config
|
805 |
+
if k in tmp_template_dict:
|
806 |
+
if template_dict[k] is None or is_sampled(k, exp_conf):
|
807 |
+
tmp_template_dict[k] = v
|
808 |
+
# store sampled / specified params in current tmp_train_args_dict
|
809 |
+
else:
|
810 |
+
if k in tmp_train_args_dict:
|
811 |
+
if is_sampled(k, exp_conf):
|
812 |
+
# warn if key was specified from the command line
|
813 |
+
tv = tmp_train_args_dict[k]
|
814 |
+
warn(
|
815 |
+
"\nWarning: overriding sampled config-file arg",
|
816 |
+
"{} to command-line value {}\n".format(k, tv),
|
817 |
+
)
|
818 |
+
else:
|
819 |
+
tmp_train_args_dict[k] = v
|
820 |
+
|
821 |
+
# create sbatch file where required
|
822 |
+
tmp_sbatch_path = None
|
823 |
+
if sbatch_path == "hash":
|
824 |
+
tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
|
825 |
+
tmp_sbatch_name += now() + ".sh"
|
826 |
+
tmp_sbatch_path = sbatch_base / tmp_sbatch_name
|
827 |
+
tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
|
828 |
+
tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
|
829 |
+
tmp_train_args_dict["exp_file"] = str(exp_path)
|
830 |
+
else:
|
831 |
+
tmp_sbatch_path = Path(sbatch_path).resolve()
|
832 |
+
|
833 |
+
summary = extend_summary(
|
834 |
+
summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
|
835 |
+
)
|
836 |
+
|
837 |
+
# format train.py's args and crop floats' precision to 5 digits
|
838 |
+
tmp_template_dict["train_args"] = " ".join(
|
839 |
+
sorted(
|
840 |
+
[
|
841 |
+
"{}={}".format(k, clean_arg(v))
|
842 |
+
for k, v in tmp_train_args_dict.items()
|
843 |
+
]
|
844 |
+
)
|
845 |
+
)
|
846 |
+
|
847 |
+
if "resume.py" in template and resume is None:
|
848 |
+
raise ValueError("No `resume` value but using a resume.py template")
|
849 |
+
|
850 |
+
# format template with clean dict (replace None with "")
|
851 |
+
sbatch = template.format(
|
852 |
+
**{
|
853 |
+
k: v if v is not None else ""
|
854 |
+
for k, v in tmp_template_dict.items()
|
855 |
+
if k in template_dict
|
856 |
+
}
|
857 |
+
)
|
858 |
+
|
859 |
+
# --------------------------------------
|
860 |
+
# ----- Execute `sbatch` Command -----
|
861 |
+
# --------------------------------------
|
862 |
+
if not dev or force_sbatchs:
|
863 |
+
if tmp_sbatch_path.exists():
|
864 |
+
print(f"Warning: overwriting {sbatch_path}")
|
865 |
+
|
866 |
+
# write sbatch file
|
867 |
+
with open(tmp_sbatch_path, "w") as f:
|
868 |
+
f.write(sbatch)
|
869 |
+
|
870 |
+
if not dev:
|
871 |
+
# escape special characters such as " " from sbatch_path's parent dir
|
872 |
+
parent = str(tmp_sbatch_path.parent)
|
873 |
+
if escape:
|
874 |
+
parent = escape_path(parent)
|
875 |
+
|
876 |
+
# create command to execute in a subprocess
|
877 |
+
command = "sbatch {}".format(tmp_sbatch_path.name)
|
878 |
+
# execute sbatch command & store output
|
879 |
+
command_output = subprocess.run(
|
880 |
+
command.split(), stdout=subprocess.PIPE, cwd=parent
|
881 |
+
)
|
882 |
+
command_output = "\n" + command_output.stdout.decode("utf-8") + "\n"
|
883 |
+
|
884 |
+
print(f"Running from {parent}:")
|
885 |
+
print(f"$ {command}")
|
886 |
+
|
887 |
+
# ---------------------------------
|
888 |
+
# ----- Summarize Execution -----
|
889 |
+
# ---------------------------------
|
890 |
+
if verbose:
|
891 |
+
print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
|
892 |
+
if not dev:
|
893 |
+
print_box(command_output.strip())
|
894 |
+
jobID = parse_jobID(command_output.strip())
|
895 |
+
summary["Slurm JOBID"].append(jobID)
|
896 |
+
|
897 |
+
summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")
|
898 |
+
|
899 |
+
print(
|
900 |
+
"{}{}Summary{} {}:".format(
|
901 |
+
C.UNDERLINE,
|
902 |
+
C.OKGREEN,
|
903 |
+
C.ENDC,
|
904 |
+
f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
|
905 |
+
)
|
906 |
+
)
|
907 |
+
print(
|
908 |
+
" "
|
909 |
+
+ "\n ".join(
|
910 |
+
"{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
|
911 |
+
)
|
912 |
+
)
|
913 |
+
print_footer()
|
914 |
+
|
915 |
+
print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n")
|
916 |
+
|
917 |
+
table, sum_path = search_summary_table(summary, summary_dir if not dev else None)
|
918 |
+
if table is not None:
|
919 |
+
print(table)
|
920 |
+
print(
|
921 |
+
"Add `[i]: https://...` at the end of a markdown document",
|
922 |
+
"to fill in the comet links.\n",
|
923 |
+
)
|
924 |
+
if summary_dir is None:
|
925 |
+
print("Add summary_dir=path to store the printed markdown table ⇪")
|
926 |
+
else:
|
927 |
+
print("Saved table in", str(sum_path))
|
928 |
+
|
929 |
+
if not dev:
|
930 |
+
print(
|
931 |
+
"Cancel entire experiment? \n$ scancel",
|
932 |
+
" ".join(map(str, summary["Slurm JOBID"])),
|
933 |
+
)
|
shared/experiment/showcase.yaml
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--- # ---------------------------
|
2 |
+
# `sample` can be
|
3 |
+
# - `uniform` (np.random.uniform(*from))
|
4 |
+
# - `range` (np.choice(np.arange(*from)))
|
5 |
+
# - `list` (np.choice(from))
|
6 |
+
# - `cartesian` special case where a cartesian product of all keys with the `cartesian` sampling scheme
|
7 |
+
# is created and iterated over in order. `from` MUST be a list
|
8 |
+
# As we iterate over the cartesian product of all
|
9 |
+
# such keys, others are sampled as usual. If n_search is larger than the size of the cartesian
|
10 |
+
# product, it will cycle again through the product in the same order
|
11 |
+
# example with A being `cartesian` from [1, 2] and B from [y, z] and 5 searches:
|
12 |
+
# => {A:1, B: y}, {A:1, B: z}, {A:2, B: y}, {A:2, B: z}, {A:1, B: y}
|
13 |
+
# - `sequential` samples will loop through the values in `from`. `from` MUST be a list
|
14 |
+
|
15 |
+
# ---------------------------
|
16 |
+
# ----- SBATCH config -----
|
17 |
+
cpus: 8
|
18 |
+
partition: long
|
19 |
+
mem: 32G
|
20 |
+
gres: "gpu:rtx8000:1"
|
21 |
+
codeloc: $HOME/ccai/climategan
|
22 |
+
|
23 |
+
modules: "module load anaconda/3 && module load pytorch"
|
24 |
+
conda: "conda activate climatenv && conda deactivate && conda activate climatenv"
|
25 |
+
|
26 |
+
n_search: -1
|
27 |
+
|
28 |
+
# ------------------------
|
29 |
+
# ----- Train Args -----
|
30 |
+
# ------------------------
|
31 |
+
|
32 |
+
"args.note": "Hyper Parameter search #1"
|
33 |
+
"args.comet_tags": ["masker_search", "v1"]
|
34 |
+
"args.config": "config/trainer/my_config.yaml"
|
35 |
+
|
36 |
+
# --------------------------
|
37 |
+
# ----- Model config -----
|
38 |
+
# --------------------------
|
39 |
+
"gen.opt.lr":
|
40 |
+
sample: list
|
41 |
+
from: [0.01, 0.001, 0.0001, 0.00001]
|
42 |
+
|
43 |
+
"dis.opt.lr":
|
44 |
+
sample: uniform
|
45 |
+
from: [0.01, 0.001]
|
46 |
+
|
47 |
+
"dis.opt.optimizer":
|
48 |
+
sample: cartesian
|
49 |
+
from:
|
50 |
+
- ExtraAdam
|
51 |
+
- Adam
|
52 |
+
|
53 |
+
"gen.opt.optimizer":
|
54 |
+
sample: cartesian
|
55 |
+
from:
|
56 |
+
- ExtraAdam
|
57 |
+
- Adam
|
58 |
+
|
59 |
+
"gen.lambdas.C":
|
60 |
+
sample: cartesian
|
61 |
+
from:
|
62 |
+
- 0.1
|
63 |
+
- 0.5
|
64 |
+
- 1
|
65 |
+
|
66 |
+
"data.loaders.batch_size":
|
67 |
+
sample: sequential
|
68 |
+
from:
|
69 |
+
- 2
|
70 |
+
- 4
|
71 |
+
- 6
|
shared/template/mila_victor.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition={partition}
|
3 |
+
#SBATCH --cpus-per-task={cpus}
|
4 |
+
#SBATCH --mem={mem}
|
5 |
+
#SBATCH --gres={gres}
|
6 |
+
#SBATCH --output={output}
|
7 |
+
|
8 |
+
module purge
|
9 |
+
|
10 |
+
{modules}
|
11 |
+
|
12 |
+
{conda}
|
13 |
+
|
14 |
+
export PYTHONUNBUFFERED=1
|
15 |
+
|
16 |
+
cd {codeloc}
|
17 |
+
|
18 |
+
echo "Currently using:"
|
19 |
+
echo $(which python)
|
20 |
+
echo "in:"
|
21 |
+
echo $(pwd)
|
22 |
+
echo "sbatch file name: $0"
|
23 |
+
|
24 |
+
python train.py {train_args}
|
shared/template/resume_mila_victor.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --partition={partition}
|
3 |
+
#SBATCH --cpus-per-task={cpus}
|
4 |
+
#SBATCH --mem={mem}
|
5 |
+
#SBATCH --gres={gres}
|
6 |
+
#SBATCH --output={output}
|
7 |
+
|
8 |
+
module purge
|
9 |
+
|
10 |
+
{modules}
|
11 |
+
|
12 |
+
{conda}
|
13 |
+
|
14 |
+
export PYTHONUNBUFFERED=1
|
15 |
+
|
16 |
+
cd {codeloc}
|
17 |
+
|
18 |
+
echo "Currently using:"
|
19 |
+
echo $(which python)
|
20 |
+
echo "in:"
|
21 |
+
echo $(pwd)
|
22 |
+
echo "sbatch file: $0"
|
23 |
+
|
24 |
+
python resume.py --path {resume}
|
shared/trainer/config.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HYDRA CONFIG
|
2 |
+
|
3 |
+
# defaults:
|
4 |
+
# - defaults
|
5 |
+
|
6 |
+
args:
|
7 |
+
config: null # "What configuration file to use to overwrite shared/defaults.yaml"
|
8 |
+
note: null # Note about this training for comet logging
|
9 |
+
no_comet: False # DON'T use comet.ml to log experiment
|
10 |
+
resume: False # Load latest ckpt
|
11 |
+
tags: null
|
12 |
+
dev: False # Run this script in development mode
|
13 |
+
|
14 |
+
hydra:
|
15 |
+
run:
|
16 |
+
dir: .
|
shared/trainer/defaults.yaml
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_path: /miniscratch/_groups/ccai/trash
|
2 |
+
# README on load_path
|
3 |
+
# 1/ any path which leads to a dir will be loaded as `path / checkpoints / latest_ckpt.pth`
|
4 |
+
# 2/ if you want to specify a specific checkpoint, it MUST be a `.pth` file
|
5 |
+
# 3/ resuming a P OR an M model, you may only specify 1 of `load_path.p` OR `load_path.m`.
|
6 |
+
# You may also leave BOTH at none, in which case `output_path / checkpoints / latest_ckpt.pth`
|
7 |
+
# will be used
|
8 |
+
# 4/ resuming a P+M model, you may specify (`p` AND `m`) OR `pm` OR leave all at none,
|
9 |
+
# in which case `output_path / checkpoints / latest_ckpt.pth` will be used to load from
|
10 |
+
# a single checkpoint
|
11 |
+
load_paths:
|
12 |
+
p: none # Painter weights: none will use `output_path / checkpoints / latest_ckpt.pth`
|
13 |
+
m: none # Masker weights: none will use `output_path / checkpoints / latest_ckpt.pth`
|
14 |
+
pm: none # Painter and Masker weights: none will use `output_path / checkpoints / latest_ckpt.pth`
|
15 |
+
|
16 |
+
# -------------------
|
17 |
+
# ----- Tasks -----
|
18 |
+
# -------------------
|
19 |
+
tasks: [d, s, m, p] # [p] [m, s, d]
|
20 |
+
|
21 |
+
# ----------------
|
22 |
+
# ----- Data -----
|
23 |
+
# ----------------
|
24 |
+
data:
|
25 |
+
max_samples: -1 # -1 for all, otherwise set to an int to crop the training data size
|
26 |
+
files: # if one is not none it will override the dirs location
|
27 |
+
base: /miniscratch/_groups/ccai/data/jsons
|
28 |
+
train:
|
29 |
+
r: train_r_full.json
|
30 |
+
s: train_s_fixedholes.json
|
31 |
+
rf: train_rf.json
|
32 |
+
kitti: train_kitti.json
|
33 |
+
val:
|
34 |
+
r: val_r_full.json
|
35 |
+
s: val_s_fixedholes.json
|
36 |
+
rf: val_rf_labelbox.json
|
37 |
+
kitti: val_kitti.json
|
38 |
+
check_samples: False
|
39 |
+
loaders:
|
40 |
+
batch_size: 6
|
41 |
+
num_workers: 6
|
42 |
+
normalization: default # can be "default" or "HRNet" for now. # default: mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]; HRNet: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
43 |
+
transforms:
|
44 |
+
- name: hflip
|
45 |
+
ignore: val
|
46 |
+
p: 0.5
|
47 |
+
- name: resize
|
48 |
+
ignore: false
|
49 |
+
new_size: 640
|
50 |
+
keep_aspect_ratio: true # smallest dimension will be `new_size` and the other will be computed to keep aspect ratio
|
51 |
+
- name: crop
|
52 |
+
ignore: false
|
53 |
+
center: val # disable randomness, crop around the image's center
|
54 |
+
height: 600
|
55 |
+
width: 600
|
56 |
+
- name: brightness
|
57 |
+
ignore: val
|
58 |
+
- name: saturation
|
59 |
+
ignore: val
|
60 |
+
- name: contrast
|
61 |
+
ignore: val
|
62 |
+
- name: resize
|
63 |
+
ignore: false
|
64 |
+
new_size:
|
65 |
+
default: 640
|
66 |
+
d: 160
|
67 |
+
s: 160
|
68 |
+
|
69 |
+
# ---------------------
|
70 |
+
# ----- Generator -----
|
71 |
+
# ---------------------
|
72 |
+
gen:
|
73 |
+
opt:
|
74 |
+
optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
|
75 |
+
beta1: 0.9
|
76 |
+
lr:
|
77 |
+
default: 0.00005 # 0.00001 for dlv2, 0.00005 for dlv3
|
78 |
+
lr_policy: step
|
79 |
+
# lr_policy can be constant, step or multi_step; if step, specify lr_step_size and lr_gamma
|
80 |
+
# if multi_step specify lr_step_size lr_gamma and lr_milestones:
|
81 |
+
# if lr_milestones is a list:
|
82 |
+
# the learning rate will be multiplied by gamma each time the epoch reaches an
|
83 |
+
# item in the list (no need for lr_step_size).
|
84 |
+
# if lr_milestones is an int:
|
85 |
+
# a list of milestones is created from `range(lr_milestones, train.epochs, lr_step_size)`
|
86 |
+
lr_step_size: 5 # for linear decay : period of learning rate decay (epochs)
|
87 |
+
lr_milestones: 15
|
88 |
+
lr_gamma: 0.5 # Multiplicative factor of learning rate decay
|
89 |
+
default:
|
90 |
+
&default-gen # default parameters for the generator (encoder and decoders)
|
91 |
+
activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
|
92 |
+
init_gain: 0.02
|
93 |
+
init_type: xavier
|
94 |
+
n_res: 1 # number of residual blocks before upsampling
|
95 |
+
n_downsample: &n_downsample 3 # number of downsampling layers in encoder | dim 32 + down 3 => z = 256 x 32 x 32
|
96 |
+
n_upsample: *n_downsample # upsampling in spade decoder ; should match encoder.n_downsample
|
97 |
+
pad_type: reflect # padding type [zero/reflect]
|
98 |
+
norm: spectral # ResBlock normalization ; one of {"batch", "instance", "layer", "adain", "spectral", "none"}
|
99 |
+
proj_dim: 32 # Dim of projection from latent space
|
100 |
+
encoder: # specific params for the encoder
|
101 |
+
<<: *default-gen
|
102 |
+
dim: 32
|
103 |
+
architecture: deeplabv3 # [deeplabv2/v3 resnet -> res_dim=2048) | dlv3 mobilenet -> res_dim=320
|
104 |
+
input_dim: 3 # input number of channels
|
105 |
+
n_res: 0 # number of residual blocks in content encoder/decoder
|
106 |
+
norm: spectral # ConvBlock normalization ; one of {"batch", "instance", "layer", "adain", "spectral", "none"}
|
107 |
+
|
108 |
+
#! Don't change!!!
|
109 |
+
deeplabv2:
|
110 |
+
nblocks: [3, 4, 23, 3]
|
111 |
+
use_pretrained: True
|
112 |
+
pretrained_model: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv2/DeepLab_resnet_pretrained_imagenet.pth"
|
113 |
+
|
114 |
+
deeplabv3:
|
115 |
+
backbone: resnet # resnet or mobilenet
|
116 |
+
output_stride: 8 # 8 or 16
|
117 |
+
use_pretrained: true
|
118 |
+
pretrained_model:
|
119 |
+
mobilenet: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv3/deeplabv3_plus_mobilenetv2_segmentron.pth"
|
120 |
+
resnet: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv3/model_CoinCheungDeepLab-v3-plus.pth"
|
121 |
+
|
122 |
+
d: # specific params for the depth estimation decoder
|
123 |
+
<<: *default-gen
|
124 |
+
output_dim: 1
|
125 |
+
norm: batch
|
126 |
+
loss: sigm # dada or sigm | /!\ ignored if classify.enable
|
127 |
+
upsample_featuremaps: True # upsamples from 80x80 to 160x160 intermediate feature maps
|
128 |
+
architecture: dada # dada or base | must be base for classif
|
129 |
+
classify: # classify log-depth instead of regression
|
130 |
+
enable: False
|
131 |
+
linspace:
|
132 |
+
min: 0.35
|
133 |
+
max: 6.95
|
134 |
+
buckets: 256
|
135 |
+
s: # specific params for the semantic segmentation decoder
|
136 |
+
<<: *default-gen
|
137 |
+
num_classes: 11
|
138 |
+
output_dim: 11
|
139 |
+
use_advent: True
|
140 |
+
use_minent: True
|
141 |
+
architecture: deeplabv3
|
142 |
+
upsample_featuremaps: False # upsamples from 80x80 to 160x160 intermediate feature maps
|
143 |
+
use_dada: True
|
144 |
+
p: # specific params for the SPADE painter
|
145 |
+
<<: *default-gen
|
146 |
+
latent_dim: 640
|
147 |
+
loss: gan # gan or hinge
|
148 |
+
no_z: true # <=> use_vae=False in the SPADE repo
|
149 |
+
output_dim: 3 # output dimension
|
150 |
+
pad_type: reflect # padding type [zero/reflect]
|
151 |
+
paste_original_content: True # only select the water painted to backprop through the network, not the whole generated image: fake_flooded = masked_x + m * fake_flooded
|
152 |
+
pl4m_epoch: 49 # epoch from which we introduce a new loss to the masker: the painter's discriminator's loss
|
153 |
+
spade_kernel_size: 3 # kernel size within SPADE norm layers
|
154 |
+
spade_n_up: 7 # number of upsampling layers in the translation decoder is equal to number of downsamplings in the encoder. output's h and w are z's h and w x 2^spade_num_upsampling_layers | z:32 and spade_n_up:4 => output 512
|
155 |
+
spade_param_free_norm: instance # what param-free normalization to apply in SPADE normalization
|
156 |
+
spade_use_spectral_norm: true
|
157 |
+
use_final_shortcut: False # if true, the last spade block does not get the masked input as conditioning but the prediction of the previous layer (passed through a conv to match dims) in order to lighten the masking restrictions and have smoother edges
|
158 |
+
diff_aug:
|
159 |
+
use: False
|
160 |
+
do_color_jittering: false
|
161 |
+
do_cutout: false
|
162 |
+
cutout_ratio: 0.5
|
163 |
+
do_translation: false
|
164 |
+
translation_ratio: 0.125
|
165 |
+
|
166 |
+
m: # specific params for the mask-generation decoder
|
167 |
+
<<: *default-gen
|
168 |
+
use_spade: False
|
169 |
+
output_dim: 1
|
170 |
+
use_minent: True # directly minimize the entropy of the image
|
171 |
+
use_minent_var: True # add variance of entropy map in the measure of entropy for a certain picture
|
172 |
+
use_advent: True # minimize the entropy of the image by adversarial training
|
173 |
+
use_ground_intersection: True
|
174 |
+
use_proj: True
|
175 |
+
proj_dim: 64
|
176 |
+
use_pl4m: False
|
177 |
+
n_res: 3
|
178 |
+
use_low_level_feats: True
|
179 |
+
use_dada: False
|
180 |
+
spade:
|
181 |
+
latent_dim: 128
|
182 |
+
detach: false # detach s_pred and d_pred conditioning tensors
|
183 |
+
cond_nc: 15 # 12 without x, 15 with x
|
184 |
+
spade_use_spectral_norm: True
|
185 |
+
spade_param_free_norm: batch
|
186 |
+
num_layers: 3
|
187 |
+
activations:
|
188 |
+
all_lrelu: True
|
189 |
+
|
190 |
+
# -------------------------
|
191 |
+
# ----- Discriminator -----
|
192 |
+
# -------------------------
|
193 |
+
dis:
|
194 |
+
soft_shift: 0.2 # label smoothing: real in U(1-soft_shift, 1), fake in U(0, soft_shift) # ! one-sided label smoothing
|
195 |
+
flip_prob: 0.05 # label flipping
|
196 |
+
opt:
|
197 |
+
optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
|
198 |
+
beta1: 0.5
|
199 |
+
lr:
|
200 |
+
default: 0.00002 # 0.0001 for dlv2, 0.00002 for dlv3
|
201 |
+
lr_policy: step
|
202 |
+
# lr_policy can be constant, step or multi_step; if step, specify lr_step_size and lr_gamma
|
203 |
+
# if multi_step specify lr_step_size lr_gamma and lr_milestones:
|
204 |
+
# if lr_milestones is a list:
|
205 |
+
# the learning rate will be multiplied by gamma each time the epoch reaches an
|
206 |
+
# item in the list (no need for lr_step_size).
|
207 |
+
# if lr_milestones is an int:
|
208 |
+
# a list of milestones is created from `range(lr_milestones, train.epochs, lr_step_size)`
|
209 |
+
lr_step_size: 15 # for linear decay : period of learning rate decay (epochs)
|
210 |
+
lr_milestones: 5
|
211 |
+
lr_gamma: 0.5 # Multiplicative factor of learning rate decay
|
212 |
+
default:
|
213 |
+
&default-dis # default setting for discriminators (there are 4 of them for rn rf sn sf)
|
214 |
+
input_nc: 3
|
215 |
+
ndf: 64
|
216 |
+
n_layers: 4
|
217 |
+
norm: instance
|
218 |
+
init_type: xavier
|
219 |
+
init_gain: 0.02
|
220 |
+
use_sigmoid: false
|
221 |
+
num_D: 1 #Number of discriminators to use (>1 means multi-scale)
|
222 |
+
get_intermediate_features: false
|
223 |
+
p:
|
224 |
+
<<: *default-dis
|
225 |
+
num_D: 3
|
226 |
+
get_intermediate_features: true
|
227 |
+
use_local_discriminator: false
|
228 |
+
# ttur: false # two time-scale update rule (see SPADE repo)
|
229 |
+
m:
|
230 |
+
<<: *default-dis
|
231 |
+
multi_level: false
|
232 |
+
architecture: base # can be [base | OmniDiscriminator]
|
233 |
+
gan_type: WGAN_norm # can be [GAN | WGAN | WGAN_gp | WGAN_norm]
|
234 |
+
wgan_clamp_lower: -0.01 # used in WGAN, WGAN clap the params in dis to [wgan_clamp_lower, wgan_clamp_upper] for every update
|
235 |
+
wgan_clamp_upper: 0.01 # used in WGAN
|
236 |
+
s:
|
237 |
+
<<: *default-dis
|
238 |
+
gan_type: WGAN_norm # can be [GAN | WGAN | WGAN_gp | WGAN_norm]
|
239 |
+
wgan_clamp_lower: -0.01 # used in WGAN, WGAN clap the params in dis to [wgan_clamp_lower, wgan_clamp_upper] for every update
|
240 |
+
wgan_clamp_upper: 0.01 # used in WGAN
|
241 |
+
# -------------------------------
|
242 |
+
# ----- Domain Classifier -----
|
243 |
+
# -------------------------------
|
244 |
+
classifier:
|
245 |
+
opt:
|
246 |
+
optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
|
247 |
+
beta1: 0.5
|
248 |
+
lr:
|
249 |
+
default: 0.0005
|
250 |
+
lr_policy: step # constant or step ; if step, specify step_size and gamma
|
251 |
+
lr_step_size: 30 # for linear decay
|
252 |
+
lr_gamma: 0.5
|
253 |
+
loss: l2 #Loss can be l1, l2, cross_entropy. default cross_entropy
|
254 |
+
layers: [100, 100, 20, 20, 4] # number of units per hidden layer ; las number is output_dim
|
255 |
+
dropout: 0.4 # probability of being set to 0
|
256 |
+
init_type: kaiming
|
257 |
+
init_gain: 0.2
|
258 |
+
proj_dim: 128 #Dim of projection from latent space
|
259 |
+
|
260 |
+
# ------------------------
|
261 |
+
# ----- Train Params -----
|
262 |
+
# ------------------------
|
263 |
+
train:
|
264 |
+
kitti:
|
265 |
+
pretrain: False
|
266 |
+
epochs: 10
|
267 |
+
batch_size: 6
|
268 |
+
amp: False
|
269 |
+
pseudo:
|
270 |
+
tasks: [] # list of tasks for which to use pseudo labels (empty list to disable)
|
271 |
+
epochs: 10 # disable pseudo training after n epochs (set to -1 to never disable)
|
272 |
+
epochs: 300
|
273 |
+
fid:
|
274 |
+
n_images: 57 # val_rf.json has 57 images
|
275 |
+
batch_size: 50 # inception inference batch size, not painter's
|
276 |
+
dims: 2048 # what Inception bock to compute the stats from (see BLOCK_INDEX_BY_DIM in fid.py)
|
277 |
+
latent_domain_adaptation: False # whether or not to do domain adaptation on the latent vectors # Needs to be turned off if use_advent is True
|
278 |
+
lambdas: # scaling factors in the total loss
|
279 |
+
G:
|
280 |
+
d:
|
281 |
+
main: 1
|
282 |
+
gml: 0.5
|
283 |
+
s:
|
284 |
+
crossent: 1
|
285 |
+
crossent_pseudo: 0.001
|
286 |
+
minent: 0.001
|
287 |
+
advent: 0.001
|
288 |
+
m:
|
289 |
+
bce: 1 # Main prediction loss, i.e. GAN or BCE
|
290 |
+
tv: 1 # Total variational loss (for smoothing)
|
291 |
+
gi: 0.05
|
292 |
+
pl4m: 1 # painter loss for the masker (end-to-end)
|
293 |
+
p:
|
294 |
+
context: 0
|
295 |
+
dm: 1 # depth matching
|
296 |
+
featmatch: 10
|
297 |
+
gan: 1 # gan loss
|
298 |
+
reconstruction: 0
|
299 |
+
tv: 0
|
300 |
+
vgg: 10
|
301 |
+
classifier: 1
|
302 |
+
C: 1
|
303 |
+
advent:
|
304 |
+
ent_main: 0.5 # the coefficient of the MinEnt loss that directly minimize the entropy of the image
|
305 |
+
ent_aux: 0.0 # the corresponding coefficient of the MinEnt loss of second output
|
306 |
+
ent_var: 0.1 # the proportion of variance of entropy map in the entropy measure for a certain picture
|
307 |
+
adv_main: 1.0 # the coefficient of the AdvEnt loss that minimize the entropy of the image by adversarial training
|
308 |
+
adv_aux: 0.0 # the corresponding coefficient of the AdvEnt loss of second output
|
309 |
+
dis_main: 1.0 # the discriminator take care of the first output in the adversarial training
|
310 |
+
dis_aux: 0.0 # the discriminator take care of the second output in the adversarial training
|
311 |
+
WGAN_gp: 10 # used in WGAN_gp, it's the hyperparameters for the gradient penalty
|
312 |
+
log_level: 2 # 0: no log, 1: only aggregated losses, >1 detailed losses
|
313 |
+
save_n_epochs: 25 # Save `latest_ckpt.pth` every epoch, `epoch_{epoch}_ckpt.pth` model every n epochs if epoch >= min_save_epoch
|
314 |
+
min_save_epoch: 28 # Save extra intermediate checkpoints when epoch > min_save_epoch
|
315 |
+
resume: false # Load latest_ckpt.pth checkpoint from `output_path` #TODO Make this path of checkpoint to load
|
316 |
+
auto_resume: true # automatically looks for similar output paths and exact same jobID to resume training automatically even if resume is false.
|
317 |
+
|
318 |
+
# -----------------------------
|
319 |
+
# ----- Validation Params -----
|
320 |
+
# -----------------------------
|
321 |
+
val:
|
322 |
+
store_images: false # write to disk on top of comet logging
|
323 |
+
val_painter: /miniscratch/_groups/ccai/checkpoints/painter/victor/good_large_lr/checkpoints/latest_ckpt.pth
|
324 |
+
# -----------------------------
|
325 |
+
# ----- Comet Params ----------
|
326 |
+
# -----------------------------
|
327 |
+
comet:
|
328 |
+
display_size: 20
|
329 |
+
rows_per_log: 5 # number of samples (rows) in a logged grid image. Number of total logged images: display_size // rows_per_log
|
330 |
+
im_per_row: # how many columns (3 = x, target, pred)
|
331 |
+
p: 4
|
332 |
+
m: 6
|
333 |
+
s: 4
|
334 |
+
d: 4
|