|
import numpy as np |
|
import torch, h5py |
|
from model import * |
|
# Input geometry: 28 x 28 images with a single channel (MNIST-sized).
h = w = 28
c = 1

# Build the new PyTorch autoencoder. Its hyper-parameters have to yield the
# same parameter shapes as the old checkpoint, because every parameter of
# this model is later overwritten from the stored weights.
model_new = DeepConvAE(
    w=w, h=h, c=c,
    nb_filters=128,
    spatial=True,
    channel=True,
    channel_stride=4,
    nb_layers=3,
)
|
|
|
# Transfer the weights of the old HDF5 MNIST autoencoder checkpoint into the
# freshly built PyTorch ``model_new``, then save the converted model.
#
# The checkpoint is opened explicitly read-only ("r") inside a context manager,
# so the handle is released once every parameter has been copied. Without a
# mode, older h5py releases open files read/write, and the handle was
# previously never closed.
with h5py.File(
    "/home/mehdi/work/code/out_of_class/ae/mnist/model.h5", "r"
) as model_old:
    print(model_new)
    print(model_old["model_weights"].keys())

    for name, param in model_new.named_parameters():
        # Parameter names are expected to have exactly three dot-separated
        # parts, e.g. "encode.0.weight"; the unpacking raises otherwise.
        enc_or_decode, layer_id, bias_or_kernel = name.split(".")

        # "encode" parameters come from the old "conv2d_<i>" layers, all
        # other parameters from "up_conv2d_<i>".
        layer_name = "conv2d" if enc_or_decode == "encode" else "up_conv2d"

        # NOTE(review): halving presumably skips parameter-free modules
        # (e.g. activations) interleaved in the PyTorch container, and the
        # old layer names are 1-based -- confirm against DeepConvAE.
        layer_id = int(layer_id) // 2 + 1

        full_layer_name = f"{layer_name}_{layer_id}"
        print(full_layer_name)

        # PyTorch "weight" maps to the stored "kernel"; anything else is read
        # from the stored "bias".
        k = "kernel" if bias_or_kernel == "weight" else "bias"
        stored = model_old["model_weights"][full_layer_name][full_layer_name][k][()]
        weights = torch.from_numpy(np.asarray(stored))
        print(name, layer_id, param.shape, weights.shape)

        if k == "kernel":
            if layer_name == "conv2d":
                # Reorder axes (0, 1, 2, 3) -> (3, 2, 0, 1). This assumes the
                # stored kernel is (kh, kw, in, out), giving PyTorch's Conv2d
                # layout (out, in, kh, kw) -- TODO confirm. Both trailing axes
                # are then reversed. torch.flip does this for any kernel size;
                # the former hard-coded index list [4, 3, 2, 1, 0] only worked
                # for 5x5 kernels.
                weights = torch.flip(weights.permute(3, 2, 0, 1), dims=(2, 3))
                print("W", weights.shape)
            elif layer_name == "up_conv2d":
                # Swap axis pair (0, 1) with (2, 3). Unlike the encoder branch,
                # no axis reversal is applied here.
                # NOTE(review): confirm this asymmetry is intended.
                weights = weights.permute(2, 3, 0, 1)
                print(param.shape, weights.shape)

        # Tensor.copy_ silently broadcasts a smaller source tensor, which would
        # hide a layer-mapping mistake; require an exact shape match instead.
        if weights.shape != param.shape:
            raise ValueError(
                f"shape mismatch for {name} <- {full_layer_name}/{k}: "
                f"expected {tuple(param.shape)}, got {tuple(weights.shape)}"
            )

        # Overwrite the parameter in place without recording autograd history.
        with torch.no_grad():
            param.copy_(weights)

        # Largest absolute deviation after the copy. A plain signed sum could
        # cancel to zero even when the tensors differ.
        print((param.detach() - weights).abs().max())

torch.save(model_new, "mnist_deepconvae/model.th")
|
|