Spaces:
Runtime error
Runtime error
Update src/music2cocktailrep/training/latent_translation/setup_trained_model.py
Browse files
src/music2cocktailrep/training/latent_translation/setup_trained_model.py
CHANGED
@@ -33,8 +33,6 @@ def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
|
|
33 |
latent_dim=params['latent_dim'],
|
34 |
nb_classes=params['nb_classes'],
|
35 |
dropout=params['dropout'])
|
36 |
-
print('HEREEE: ', torch.sum(torch.Tensor([param.sum() for param in list(model.parameters())])))
|
37 |
-
print('model hash: ', hashlib.md5(open(model_path, 'rb').read()).hexdigest())
|
38 |
stats = params['stats']
|
39 |
stats_music = np.array(stats['mean_std_music_rep'])
|
40 |
stats_cocktail = np.array(stats['mean_std_cocktail_rep_norm11'])
|
@@ -44,7 +42,8 @@ def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
|
|
44 |
|
45 |
model.load_state_dict(torch.load(model_path))
|
46 |
model.eval()
|
47 |
-
|
|
|
48 |
def denormalize_cocktail_output(output):
|
49 |
return output * stats_cocktail[1] + stats_cocktail[0]
|
50 |
|
|
|
33 |
latent_dim=params['latent_dim'],
|
34 |
nb_classes=params['nb_classes'],
|
35 |
dropout=params['dropout'])
|
|
|
|
|
36 |
stats = params['stats']
|
37 |
stats_music = np.array(stats['mean_std_music_rep'])
|
38 |
stats_cocktail = np.array(stats['mean_std_cocktail_rep_norm11'])
|
|
|
42 |
|
43 |
model.load_state_dict(torch.load(model_path))
|
44 |
model.eval()
|
45 |
+
print('HEREEE: ', torch.sum(torch.Tensor([param.sum() for param in list(model.parameters())])))
|
46 |
+
print('model hash: ', hashlib.md5(open(model_path, 'rb').read()).hexdigest())
|
47 |
def denormalize_cocktail_output(output):
|
48 |
return output * stats_cocktail[1] + stats_cocktail[0]
|
49 |
|