TastyPiano / src /music /pipeline /encoded2rep.py
ccolas's picture
Update src/music/pipeline/encoded2rep.py
004e4fc
raw
history blame
3.62 kB
from src.music.utilities.representation_learning_utilities.constants import *
from src.music.config import REP_MODEL_NAME
from src.music.utils import get_out_path
import pickle
import numpy as np
# from transformers import AutoModel, AutoTokenizer
from torch import nn
from src.music.representation_learning.sentence_transfo.sentence_transformers import SentenceTransformer
class Argument(object):
    """Lightweight namespace exposing the entries of a mapping as attributes.

    Example: ``Argument({'lr': 0.1}).lr == 0.1``.
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__, so every key becomes an attribute.
        vars(self).update(adict)
class RepModel(nn.Module):
    """Wrap a transformer model so that calling it yields one pooled embedding.

    The wrapped model must accept ``output_hidden_states=True`` and return an
    object with a ``hidden_states`` sequence (presumably a Hugging Face
    ``transformers`` model -- confirm against callers).
    """

    def __init__(self, model, model_name):
        super().__init__()
        # For T5 checkpoints only the encoder half is kept.
        encoder = model.get_encoder() if 't5' in model_name else model
        self.model = encoder
        # Inference-only wrapper: switch off dropout etc.
        self.model.eval()

    def forward(self, inputs):
        """Return the mean over axis 0 of the first batch item of the last hidden layer.

        NOTE(review): assumes the last hidden state is (batch, seq, hidden) and the
        batch holds a single item -- other batch items are ignored.
        """
        with torch.no_grad():
            outputs = self.model(inputs, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]
            first_item = last_layer[0]
            return first_item.mean(dim=0)
# def get_trained_music_LM(model_name):
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
# model = RepModel(AutoModel.from_pretrained(model_name, use_auth_token=True), model_name)
#
# return model, tokenizer
def get_trained_sentence_embedder(model_name):
    """Build and return the project ``SentenceTransformer`` for ``model_name``."""
    return SentenceTransformer(model_name)


# Loaded once, at import time; encoded2rep uses this shared instance.
MODEL = get_trained_sentence_embedder(model_name=REP_MODEL_NAME)
def encoded2rep(encoded_path, rep_path=None, return_rep=False, verbose=False, level=0):
    """Map an encoded performance file to a music representation vector and save it.

    The pickle at ``encoded_path`` is loaded. Every token of ``data['main']``
    except the value 1 is stringified, and the tokens are joined with spaces.
    The resulting string is embedded with the module-level ``MODEL``
    (``MODEL.encode``), and the embedding is written to ``rep_path`` with
    ``np.savetxt``.

    Args:
        encoded_path: Path to a pickle containing a mapping with a ``'main'``
            token sequence.
        rep_path: Output path. When falsy it is derived from ``encoded_path``
            with ``get_out_path`` (``'encoded'`` -> ``'represented'``, ``.txt``).
        return_rep: If True, also return the computed representation.
        verbose: If True, print progress / failure messages.
        level: Number of spaces prefixed to progress messages.

    Returns:
        On success: ``(rep_path, representation, '')`` if ``return_rep``, else
        ``(rep_path, '')``.
        On failure: ``(None, None, error_msg)`` or ``(None, error_msg)``.
        ``error_msg`` accumulates one sentence per stage reached, so its last
        sentence points at the stage that failed.

    Note:
        Errors raised while deriving ``rep_path`` (before the ``try``) propagate
        to the caller. Every later ``Exception`` is caught and reported through
        ``error_msg``.
    """
    if not rep_path:
        rep_path, _, _ = get_out_path(in_path=encoded_path, in_word='encoded', out_word='represented', out_extension='.txt')
    error_msg = 'Error in music transformer mapping.'
    if verbose:
        print(' ' * level + 'Mapping to final music representations')
    try:
        error_msg += ' Error in encoded file loading?'
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # trusted, locally produced encoding files here.
        with open(encoded_path, 'rb') as f:
            data = pickle.load(f)
        # Tokens equal to 1 are dropped; presumably a padding/special token -- TODO confirm.
        performance = [str(w) for w in data['main'] if w != 1]
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if len(performance) % 5 != 0:
            error_msg += ' Error: number of tokens is not a multiple of 5'
            raise ValueError(error_msg)
        if not performance:
            error_msg += " Error: No midi messages in primer file"
            raise ValueError(error_msg)
        error_msg += ' Nope, error in tokenization?'
        perf = ' '.join(performance)
        error_msg += ' Nope. Maybe in performance encoding?'
        representation = MODEL.encode(perf)
        error_msg += ' Nope. Saving performance?'
        np.savetxt(rep_path, representation)
        error_msg += ' Nope.'
        if verbose:
            print(' ' * (level + 2) + 'Success.')
        if return_rep:
            return rep_path, representation, ''
        return rep_path, ''
    except Exception as err:
        # Best-effort API: failures are reported through error_msg instead of raised.
        # `except Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit through.
        if verbose:
            print(' ' * (level + 2) + f'Failed with error: {error_msg} ({err!r})')
        if return_rep:
            return None, None, error_msg
        return None, error_msg
if __name__ == "__main__":
    import sys

    # Encoded (pickled) performance to map: first CLI argument, or the example
    # file originally hard-coded here.
    default_path = "/home/cedric/Documents/pianocktail/data/music/encoded/single_videos_midi_processed_encoded/chris_dawson_all_of_me_.pickle"
    encoded_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    # return_rep=True so `representation` holds the vector, not the (rep_path, error_msg) tuple.
    rep_path, representation, error_msg = encoded2rep(encoded_path, return_rep=True, verbose=True)
    if rep_path is not None:
        print(f'Saved representation to {rep_path}')