# Biomap / biomap / unet.py
# jeremyLE-Ekimetrics -- first commit (5c718d1), 2.86 kB
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import defaultdict
import torchvision
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler
class Block(nn.Module):
    """Two 3x3 convolutions with a ReLU between them.

    Args:
        in_ch: Number of input channels of the first convolution.
        out_ch: Number of output channels of both convolutions.
        padding: Padding argument forwarded unchanged to both ``nn.Conv2d``
            layers.
    """

    def __init__(self, in_ch, out_ch, padding='same'):
        super().__init__()
        kernel_size = 3
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=padding)

    def forward(self, x):
        """Run conv1, ReLU, then conv2; no activation follows conv2."""
        hidden = self.conv1(x)
        activated = self.relu(hidden)
        return self.conv2(activated)
class Encoder(nn.Module):
    """Contracting path: a stack of ``Block`` modules with 2x2 max-pooling in between.

    Args:
        chs: Channel sizes. Each consecutive pair ``(chs[i], chs[i + 1])``
            gives the input/output channels of one block, so ``len(chs) - 1``
            blocks are created.
    """

    def __init__(self, chs=(3,32,64,128,256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(chs, chs[1:])]
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the output of every block, shallowest first.

        Each feature map is recorded before pooling, so the first entry has
        the spatial size produced by the first block and every later entry
        is computed on the max-pooled output of the previous block.

        Args:
            x: Input tensor fed to the first block.

        Returns:
            A list with one tensor per block in ``self.enc_blocks``.
        """
        ftrs = []
        last = len(self.enc_blocks) - 1
        for idx, block in enumerate(self.enc_blocks):
            x = block(x)
            ftrs.append(x)
            # Pool only between blocks. The pooled output of the final block
            # was never used, and computing it raised whenever the deepest
            # feature map was smaller than the 2x2 pooling window.
            if idx < last:
                x = self.pool(x)
        return ftrs
class Decoder(nn.Module):
    """Expanding path: transposed convolutions merged with encoder skip features.

    Args:
        chs: Channel sizes. Each consecutive pair gives the input/output
            channels of one decoder ``Block``.
        aux_ch: Extra channels added to the input of the first transposed
            convolution only; later stages use ``chs`` unchanged.
    """

    def __init__(self, chs=(256,128, 64, 32), aux_ch=70):
        super().__init__()
        self.chs = chs
        # Only the first stage sees the additional aux_ch input channels.
        self.upchs = tuple(
            ch + aux_ch if idx == 0 else ch for idx, ch in enumerate(chs)
        )
        # Kernel size 2 and stride 2 for every transposed convolution.
        self.upconvs = nn.ModuleList(
            [
                nn.ConvTranspose2d(src, dst, 2, 2)
                for src, dst in zip(self.upchs, self.upchs[1:])
            ]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(src, dst) for src, dst in zip(chs, chs[1:])]
        )

    def forward(self, x, encoder_features):
        """Run every decoder stage and return the final feature map.

        Per stage: apply the transposed convolution, center-crop the matching
        encoder feature to the result's height/width, concatenate both along
        dim 1 and pass the result through the stage's ``Block``.

        Args:
            x: Tensor fed to the first transposed convolution.
            encoder_features: Indexable collection of skip tensors;
                ``encoder_features[i]`` is used by stage ``i``.
        """
        stages = zip(self.upconvs, self.dec_blocks)
        for stage, (upconv, dec_block) in enumerate(stages):
            upsampled = upconv(x)
            skip = self.crop(encoder_features[stage], upsampled)
            merged = torch.cat([upsampled, skip], dim=1)
            x = dec_block(merged)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the height and width of the 4-D tensor ``x``."""
        _, _, height, width = x.shape
        cropper = torchvision.transforms.CenterCrop([height, width])
        return cropper(enc_ftrs)
class AuxUNet(nn.Module):
    """UNet with an auxiliary feature concatenated at the bottom.

    Args:
        enc_chs: Channel sizes passed to ``Encoder``.
        dec_chs: Channel sizes passed to ``Decoder``; the last entry is the
            input channel count of the 1x1 head.
        aux_ch: Channel count of the auxiliary feature, passed to ``Decoder``.
        num_class: Number of output channels of the 1x1 head.
        retain_dim: When true, the output is resized to ``out_sz``.
        out_sz: Target size handed to ``F.interpolate`` when ``retain_dim``
            is true.
    """

    def __init__(self, enc_chs=(3,32,64,128,256), dec_chs=(256,128, 64, 32), aux_ch=70, num_class=7, retain_dim=False, out_sz=(224,224)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs, aux_ch)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        # Kept on the instance so forward() can resize the output; forward
        # previously read an undefined bare name and raised NameError.
        self.out_sz = out_sz

    def forward(self, x, aux):
        """Compute the head output for ``x`` with ``aux`` injected at the bottom.

        Args:
            x: Input tensor passed to the encoder.
            aux: Auxiliary feature concatenated along dim 1 with the last
                encoder feature, so every other dimension must match it.

        Returns:
            The head output, resized to ``self.out_sz`` when
            ``self.retain_dim`` is true.
        """
        enc_ftrs = self.encoder(x)
        # aux: auxiliary feature at the bottom
        enc_ftrs[-1] = torch.cat((enc_ftrs[-1], aux), 1)
        # Decoder takes the deepest feature first, then the remaining skips.
        deepest_first = enc_ftrs[::-1]
        out = self.decoder(deepest_first[0], deepest_first[1:])
        out = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out