ccolas commited on
Commit
41d4286
1 Parent(s): 1a5d300

Update src/music2cocktailrep/training/latent_translation/setup_trained_model.py

Browse files
src/music2cocktailrep/training/latent_translation/setup_trained_model.py CHANGED
@@ -33,8 +33,6 @@ def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
33
  latent_dim=params['latent_dim'],
34
  nb_classes=params['nb_classes'],
35
  dropout=params['dropout'])
36
- print('HEREEE: ', torch.sum(torch.Tensor([param.sum() for param in list(model.parameters())])))
37
- print('model hash: ', hashlib.md5(open(model_path, 'rb').read()).hexdigest())
38
  stats = params['stats']
39
  stats_music = np.array(stats['mean_std_music_rep'])
40
  stats_cocktail = np.array(stats['mean_std_cocktail_rep_norm11'])
@@ -44,7 +42,8 @@ def setup_trained_model(checkpoint_path=TRANSLATION_VAE_CHKP_PATH):
44
 
45
  model.load_state_dict(torch.load(model_path))
46
  model.eval()
47
-
 
48
  def denormalize_cocktail_output(output):
49
  return output * stats_cocktail[1] + stats_cocktail[0]
50
 
 
33
  latent_dim=params['latent_dim'],
34
  nb_classes=params['nb_classes'],
35
  dropout=params['dropout'])
 
 
36
  stats = params['stats']
37
  stats_music = np.array(stats['mean_std_music_rep'])
38
  stats_cocktail = np.array(stats['mean_std_cocktail_rep_norm11'])
 
42
 
43
  model.load_state_dict(torch.load(model_path))
44
  model.eval()
45
+ print('HEREEE: ', torch.sum(torch.Tensor([param.sum() for param in list(model.parameters())])))
46
+ print('model hash: ', hashlib.md5(open(model_path, 'rb').read()).hexdigest())
47
  def denormalize_cocktail_output(output):
48
  return output * stats_cocktail[1] + stats_cocktail[0]
49