# Spaces:
# Runtime error
# Runtime error
import pretty_midi | |
from copy import deepcopy | |
import numpy as np | |
from miditok import CPWord, Structured | |
from miditoolkit import MidiFile | |
from src.music.config import MAX_EMBEDDING, CHUNK_SIZE | |
from src.music.utilities.chord_structured import ChordStructured | |
# Event vocabulary layout adapted from https://github.com/jason9693/midi-neural-processor
# Flat token ids are laid out as [note_on | note_off | time_shift | velocity].
RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100
# NOTE(review): this rebinds the MAX_EMBEDDING imported from src.music.config,
# so every use inside this module sees the locally computed vocabulary size.
MAX_EMBEDDING = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_VEL
# First flat id of each event type.
START_IDX = dict(
    note_on=0,
    note_off=RANGE_NOTE_ON,
    time_shift=RANGE_NOTE_ON + RANGE_NOTE_OFF,
    velocity=RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT,
)

# Tokenizer parameters
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32
additional_tokens = dict(
    Chord=True,
    Rest=True,
    Tempo=True,
    TimeSignature=False,
    Program=False,
    rest_range=(2, 8),  # (half, 8 beats)
    nb_tempos=32,  # nb of tempo bins
    tempo_range=(40, 250),  # (min, max)
)

# NOTE(review): the CPWord tokenizer below is disabled, yet encode_midi_cp,
# encode_miditok_in_chunks and decode_midi_cp still reference tokenizer_cp.
# tokenizer_cp = CPWord(pitch_range, beat_res, nb_velocities, additional_tokens)
tokenizer_structured = ChordStructured(pitch_range, beat_res, nb_velocities)
class SustainAdapter:
    """Record of a sustain-pedal event: when it happened and its kind.

    Attributes:
        start: time of the event (passed as ``time``).
        type: event kind label supplied by the caller.
    """

    def __init__(self, time, type):
        # ``type`` shadows the builtin only inside this method; the parameter
        # name is kept for callers that pass it by keyword.
        self.start, self.type = time, type
class SustainDownManager:
    """Collects the notes that start while one sustain-pedal press is held.

    Attributes:
        start: time the pedal went down.
        end: time the pedal came up (filled in by the caller, e.g.
            _control_preprocess, before transposition_notes is used).
        managed_notes: notes registered via add_managed_note, in order.
        _note_dict: scratch map ``pitch -> start`` of the following managed
            note with that pitch, filled while transposition_notes runs.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.managed_notes = []
        self._note_dict = {}  # key: pitch, value: note.start

    def add_managed_note(self, note: pretty_midi.Note):
        """Register ``note`` as sounding under this pedal press."""
        self.managed_notes.append(note)

    def transposition_notes(self):
        """Rewrite ``end`` of every managed note to emulate the held pedal.

        Walking the notes from last to first: a note is cut off where the
        next managed note of the same pitch starts; the last note of each
        pitch is held until ``self.end`` unless it already ends later.
        Note objects are modified in place.
        """
        next_start_by_pitch = self._note_dict
        for note in self.managed_notes[::-1]:
            if note.pitch in next_start_by_pitch:
                note.end = next_start_by_pitch[note.pitch]
            else:
                note.end = max(self.end, note.end)
            next_start_by_pitch[note.pitch] = note.start
# One half of a note after it has been divided into note_on / note_off.
class SplitNote:
    """A note_on or note_off event at a point in time.

    Attributes:
        type: 'note_on' or 'note_off'.
        time: time of the event.
        value: pitch of the note.
        velocity: note velocity (None is used for note_off halves).
    """

    def __init__(self, type, time, value, velocity):
        self.type, self.time, self.velocity, self.value = type, time, velocity, value

    def __repr__(self):
        template = '<[SNote] time: {} type: {}, value: {}, velocity: {}>'
        return template.format(self.time, self.type, self.value, self.velocity)
class Event:
    """A single (type, value) token of the event vocabulary.

    ``type`` is one of the START_IDX keys ('note_on', 'note_off',
    'time_shift', 'velocity'); the flat integer id of an event is
    ``START_IDX[type] + value``.
    """

    def __init__(self, event_type, value):
        self.type = event_type
        self.value = value

    def __repr__(self):
        return '<Event type: {}, value: {}>'.format(self.type, self.value)

    def to_int(self):
        """Return the flat integer id of this event."""
        return START_IDX[self.type] + self.value

    # from_int and _type_check take no ``self``: declaring them static keeps
    # them correct when reached through an instance as well as via the class.
    @staticmethod
    def from_int(int_value):
        """Build the Event whose flat integer id is ``int_value``."""
        info = Event._type_check(int_value)
        return Event(info['type'], info['value'])

    @staticmethod
    def _type_check(int_value):
        """Split a flat id into ``{'type': ..., 'value': ...}``.

        The first RANGE_NOTE_ON ids are note_on, the next RANGE_NOTE_OFF are
        note_off, the next RANGE_TIME_SHIFT are time_shift. Any other id,
        including negative ids and ids >= MAX_EMBEDDING, is decoded as a
        velocity event relative to the velocity offset.
        """
        range_note_on = range(0, RANGE_NOTE_ON)
        range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON + RANGE_NOTE_OFF)
        range_time_shift = range(RANGE_NOTE_ON + RANGE_NOTE_OFF,
                                 RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
        valid_value = int_value
        if int_value in range_note_on:
            return {'type': 'note_on', 'value': valid_value}
        elif int_value in range_note_off:
            valid_value -= RANGE_NOTE_ON
            return {'type': 'note_off', 'value': valid_value}
        elif int_value in range_time_shift:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
            return {'type': 'time_shift', 'value': valid_value}
        else:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
            return {'type': 'velocity', 'value': valid_value}
def _divide_note(notes):
    """Split each note into a note_on and a note_off SplitNote.

    Notes are visited in order of ``start``; the input is not modified
    (it is read through ``sorted``, so any iterable of notes is accepted).

    Args:
        notes: note objects with ``start``, ``end``, ``pitch`` and
            ``velocity``.

    Returns:
        list ``[on_0, off_0, on_1, off_1, ...]`` ordered by note start.
        note_off halves carry ``velocity=None``.
    """
    result_array = []
    for note in sorted(notes, key=lambda x: x.start):
        result_array.append(SplitNote('note_on', note.start, note.pitch, note.velocity))
        result_array.append(SplitNote('note_off', note.end, note.pitch, None))
    return result_array
def _merge_note(snote_sequence):
    """Pair note_on / note_off SplitNotes back into pretty_midi.Note objects.

    Each note_off is matched with the most recent note_on of the same pitch.
    note_on entries are never removed, so a later note_off of that pitch
    pairs with the same note_on again. Zero-length pairs are dropped; a
    note_off with no earlier note_on for its pitch is reported and skipped.
    Other SplitNote types are ignored.

    Returns:
        list of pretty_midi.Note in the order their note_off was seen.
    """
    note_on_dict = {}
    result_array = []
    for snote in snote_sequence:
        if snote.type == 'note_on':
            note_on_dict[snote.value] = snote
        elif snote.type == 'note_off':
            on = note_on_dict.get(snote.value)
            if on is None:
                # only the "no matching note_on" case is expected here; anything
                # else is a real error and should propagate
                print('info removed pitch: {}'.format(snote.value))
                continue
            if snote.time - on.time == 0:
                continue
            result_array.append(pretty_midi.Note(on.velocity, snote.value, on.time, snote.time))
    return result_array
def _snote2events(snote: SplitNote, prev_vel: int):
    """Translate one SplitNote into its Event(s).

    If the SplitNote has a velocity, it is quantised to ``velocity // 4`` and
    a velocity Event is emitted first whenever that differs from
    ``prev_vel``. The note_on/note_off Event itself always follows.
    """
    events = []
    velocity = snote.velocity
    if velocity is not None and velocity // 4 != prev_vel:
        events.append(Event(event_type='velocity', value=velocity // 4))
    events.append(Event(event_type=snote.type, value=snote.value))
    return events
def _event_seq2snote_seq(event_sequence):
    """Replay decoded Events into timed SplitNotes.

    time_shift events advance the clock by ``(value + 1) / 100`` (the inverse
    of the quantisation in _make_time_sift_events); velocity events set the
    velocity (``value * 4``) used for subsequent notes; every other event
    becomes a SplitNote at the current time. time_shift and velocity events
    themselves are not emitted as SplitNotes.
    """
    timeline = 0
    velocity = 0
    snote_seq = []
    for event in event_sequence:
        if event.type == 'time_shift':
            timeline += (event.value + 1) / 100
        elif event.type == 'velocity':
            velocity = event.value * 4
        else:
            snote_seq.append(SplitNote(event.type, timeline, event.value, velocity))
    return snote_seq
def _make_time_sift_events(prev_time, post_time):
    """Encode the gap ``post_time - prev_time`` as time_shift Events.

    The gap is quantised to hundredths with ``round``. As many maximal
    time_shift events (value ``RANGE_TIME_SHIFT - 1``) as fit are emitted,
    then one event for any non-zero remainder. A value ``v`` stands for
    ``v + 1`` steps, so an exact multiple emits no trailing event.
    """
    steps = int(round((post_time - prev_time) * 100))
    n_full = max(steps, 0) // RANGE_TIME_SHIFT
    remainder = steps - n_full * RANGE_TIME_SHIFT
    shifts = [Event(event_type='time_shift', value=RANGE_TIME_SHIFT - 1) for _ in range(n_full)]
    if remainder != 0:
        shifts.append(Event(event_type='time_shift', value=remainder - 1))
    return shifts
def _control_preprocess(ctrl_changes):
    """Turn sustain-pedal control changes into SustainDownManager spans.

    A value >= 64 opens a span (if none is open). A value < 64 closes the
    open span at that time, or, when no span is open, moves the end of the
    most recently closed span to that time. A span still open when the
    controls run out is dropped.

    Returns:
        list of closed SustainDownManager objects, in time order of closing.
    """
    sustains = []
    pending = None
    for ctrl in ctrl_changes:
        if ctrl.value >= 64:
            if pending is None:
                # sustain down
                pending = SustainDownManager(start=ctrl.time, end=None)
        elif pending is not None:
            # sustain up
            pending.end = ctrl.time
            sustains.append(pending)
            pending = None
        elif sustains:
            sustains[-1].end = ctrl.time
    return sustains
def _note_preprocess(susteins, notes):
    """Apply sustain-pedal spans to ``notes`` and return them sorted by start.

    Walks the spans in order and consumes ``notes`` from the front: notes
    starting before a span's start pass through unchanged, notes starting
    within ``[sustain.start, sustain.end]`` are handed to that span's
    SustainDownManager (whose transposition_notes() rewrites their ``end``),
    and the first note starting after ``sustain.end`` stops the scan for that
    span and is left for the next one. Notes left after the last span are
    appended unchanged.

    Assumes ``notes`` is ordered by start and ``susteins`` by time -- TODO
    confirm for every caller.

    Args:
        susteins: SustainDownManager spans (as built by _control_preprocess);
            a falsy value means no sustain pedal.
        notes: note objects with ``start``/``end``/``pitch``. The list itself
            is not modified, but managed notes get ``end`` rewritten in place.

    Returns:
        list of the same note objects, sorted by ``start``.
    """
    note_stream = []
    count_note_processed = 0  # tallied below but never read
    if susteins:  # if the midi file has sustain controls
        for sustain in susteins:
            if len(notes) > 0:
                for note_idx, note in enumerate(notes):
                    if note.start < sustain.start:
                        # starts before the pedal goes down: keep as-is
                        note_stream.append(note)
                        last_counted = True
                    elif note.start > sustain.end:
                        # first note after the release: stop, it belongs to
                        # a later span (or to the leftovers)
                        # notes = notes[note_idx:]
                        # sustain.transposition_notes()
                        last_counted = False
                        break
                    else:
                        # starts while the pedal is down
                        sustain.add_managed_note(note)
                        last_counted = True
                    count_note_processed += 1
                sustain.transposition_notes()  # rewrite ends of the notes held by this pedal press
                note_stream += sustain.managed_notes  # add to stream
                # remove notes that were already added to the stream
                # (keep the note that triggered the break, if any)
                last_idx = note_idx if not last_counted else note_idx + 1
                if last_idx < len(notes):
                    notes = notes[last_idx:]  # save next notes, previous notes were stored in note stream
                else:
                    notes = []
        note_stream += notes
        count_note_processed += len(notes)
    else:  # else, just push everything into note stream
        for note_idx, note in enumerate(notes):
            note_stream.append(note)
    note_stream.sort(key= lambda x: x.start)
    return note_stream
def midi_valid(midi) -> bool:
    """Return whether ``midi`` passes the dataset filters.

    Every filter is currently disabled, so this always returns True. The
    disabled checks were: all time signatures are 4/4, and
    ``midi.max_tick`` is at least ``10 * midi.ticks_per_beat``.
    """
    return True
def encode_midi_structured(file_path, nb_aug, nb_noise):
    """Tokenize a MIDI file with tokenizer_structured, plus augmented copies.

    The file is loaded with miditoolkit, each instrument's notes go through
    sustain-pedal preprocessing, and ``mid`` is tokenized. Augmentations
    replace ``mid.instruments[0].notes`` with a pitch/time-altered copy
    (alter_notes_exact_tick) before tokenizing again, so ``mid`` is mutated.

    Args:
        file_path: path of the MIDI file.
        nb_aug: number of augmented encodings (pitch shift of at most 3).
        nb_noise: if > 0, one extra "noise" encoding with a pitch shift of at
            most 6 is produced; otherwise it is None.

    Returns:
        (encoded_main, encoded_augmentations, embedding_noise), each taken
        as element ``[0]`` of tokenizer_structured.midi_to_tokens.

    Raises:
        AssertionError: if the preprocessed note count differs from the first
            instrument's note count (e.g. notes on several instruments).
    """
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)
    # Sustain-pedal preprocessing for every instrument. _note_preprocess may
    # rewrite note ends in place, and these are the same note objects that
    # ``mid`` holds, so the main tokenization below sees the adjusted ends.
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)
    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    # untouched copy used as the source of every augmentation
    original_notes = deepcopy(notes)
    # tokenize the (preprocessed) file as-is
    encoded_main = tokenizer_structured.midi_to_tokens(mid)[0]
    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    noise_shift = 6
    aug_shift = 3
    embedding_noise = None
    for i_aug in range(nb_aug):
        a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes
        encoded_augmentations.append(embedding_aug)
    if nb_noise > 0:
        a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_noise = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes
    return encoded_main, encoded_augmentations, embedding_noise
def encode_midi_cp(file_path, nb_aug, nb_noise):
    """Tokenize a MIDI file with tokenizer_cp (CPWord), plus augmented copies.

    NOTE(review): the module-level construction of ``tokenizer_cp`` is
    commented out, so this function raises NameError unless tokenizer_cp is
    defined elsewhere.

    Same pipeline as encode_midi_structured: sustain preprocessing, main
    tokenization of ``mid``, then augmentations written into
    ``mid.instruments[0].notes`` (so ``mid`` is mutated) and re-tokenized.

    Args:
        file_path: path of the MIDI file (loaded with miditoolkit).
        nb_aug: number of augmented encodings (pitch shift of at most 3).
        nb_noise: if > 0, one extra "noise" encoding with a pitch shift of at
            most 6 is produced; otherwise it is None.

    Returns:
        (encoded_main, encoded_augmentations, embedding_noise).

    Raises:
        AssertionError: if the preprocessed note count differs from the first
            instrument's note count.
    """
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)
    # Sustain-pedal preprocessing for every instrument (may rewrite note ends
    # in place on the note objects that ``mid`` holds).
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)
    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    # untouched copy used as the source of every augmentation
    original_notes = deepcopy(notes)
    # tokenize the (preprocessed) file as-is
    encoded_main = tokenizer_cp.midi_to_tokens(mid)[0]
    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    noise_shift = 6
    aug_shift = 3
    embedding_noise = None
    for i_aug in range(nb_aug):
        a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0]  # encode notes
        encoded_augmentations.append(embedding_aug)
    if nb_noise > 0:
        a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_noise = tokenizer_cp.midi_to_tokens(mid)[0]  # encode notes
    return encoded_main, encoded_augmentations, embedding_noise
def alter_notes_exact_tick(notes, shift, min_pitch):
    """Return a randomly pitch-shifted and time-stretched copy of ``notes``.

    Tick-domain variant: scaled ``start``/``end`` are truncated to ints.

    A pitch shift in ``[max(-shift, -min_pitch), shift]`` and a time-scaling
    percentage in {-5, -2.5, 0, 2.5, 5} are drawn from ``np.random`` until at
    least one is non-zero; the same pair is applied to every note. Shifted
    pitches are clamped to the MIDI range 0..127.

    Args:
        notes: note objects with ``start``, ``end`` and ``pitch``; they are
            deep-copied and never modified.
        shift: maximum absolute pitch shift in semitones.
        min_pitch: lowest pitch in ``notes``; bounds downward shifts.

    Returns:
        deep copy of ``notes`` with altered times and pitches.
    """
    a_notes = deepcopy(notes)
    # sample a non-trivial augmentation
    pitch_shift, time_scaling = 0, 0
    while pitch_shift == 0 and time_scaling == 0:
        pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift + 1))
        time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5])
    assert -shift <= pitch_shift <= shift
    factor = 1. + time_scaling / 100
    for e in a_notes:
        e.start = int(e.start * factor)
        e.end = int(e.end * factor)
        # MIDI note numbers must stay within 0..127 (an upward shift could
        # otherwise push high notes out of range)
        e.pitch = int(min(max(e.pitch + pitch_shift, 0), 127))
    return a_notes
def alter_notes(notes, shift, min_pitch):
    """Return a randomly pitch-shifted and time-stretched copy of ``notes``.

    Continuous-time variant: ``start``/``end`` are scaled without rounding.

    A pitch shift in ``[max(-shift, -min_pitch), shift]`` and a time-scaling
    percentage in {-5, -2.5, 0, 2.5, 5} are drawn from ``np.random`` until at
    least one is non-zero; the same pair is applied to every note. Shifted
    pitches are clamped to the MIDI range 0..127.

    Args:
        notes: note objects with ``start``, ``end`` and ``pitch``; they are
            deep-copied and never modified.
        shift: maximum absolute pitch shift in semitones.
        min_pitch: lowest pitch in ``notes``; bounds downward shifts.

    Returns:
        deep copy of ``notes`` with altered times and pitches.
    """
    a_notes = deepcopy(notes)
    # sample a non-trivial augmentation
    pitch_shift, time_scaling = 0, 0
    while pitch_shift == 0 and time_scaling == 0:
        pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift + 1))
        time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5])
    assert -shift <= pitch_shift <= shift
    factor = 1. + time_scaling / 100
    for e in a_notes:
        e.start = e.start * factor
        e.end = e.end * factor
        # MIDI note numbers must stay within 0..127
        e.pitch = int(min(max(e.pitch + pitch_shift, 0), 127))
    return a_notes
def encode_midi(file_path, nb_aug=0, nb_noise=0):
    """Encode a MIDI file as event ids, plus optional augmented variants.

    The file is loaded with pretty_midi, each instrument's notes go through
    sustain-pedal preprocessing (_control_preprocess / _note_preprocess),
    and the notes are sorted by start and encoded with convert_notes.

    Args:
        file_path: path of the MIDI file.
        nb_aug: number of augmented encodings (alter_notes with a pitch shift
            of at most 3). Defaults to 0 so a bare ``encode_midi(path)`` works.
        nb_noise: if > 0, one extra "noise" encoding (alter_notes with a pitch
            shift of at most 6) is produced; otherwise it is None.

    Returns:
        (encoded_main, encoded_augmentations, embedding_noise).

    Raises:
        AssertionError: if the preprocessed note count differs from the first
            instrument's note count (e.g. notes on several instruments), or
            if convert_notes produces an id outside [0, MAX_EMBEDDING).
    """
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)
    for inst in mid.instruments:
        # Controller 64 is the sustain pedal, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst.notes)
    assert len(notes) == len(mid.instruments[0].notes)

    # sort notes by onset
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])

    encoded_main = convert_notes(notes)
    min_pitch = np.min([n.pitch for n in notes])

    aug_shift = 3
    noise_shift = 6
    # augmentations are drawn before the noise variant, as before
    encoded_augmentations = [convert_notes(alter_notes(notes, aug_shift, min_pitch))
                             for _ in range(nb_aug)]
    embedding_noise = None
    if nb_noise > 0:
        embedding_noise = convert_notes(alter_notes(notes, noise_shift, min_pitch))
    return encoded_main, encoded_augmentations, embedding_noise
def chunk_notes(n_notes_per_chunk, notes):
    """Cut ``notes`` into consecutive slices of the requested sizes.

    Args:
        n_notes_per_chunk: sizes of successive chunks.
        notes: sequence to cut; not modified.

    Returns:
        list with one slice per requested size; notes beyond the total
        requested size are left out.
    """
    ends = []
    running = 0
    for size in n_notes_per_chunk:
        running += size
        ends.append(running)
    starts = [0] + ends[:-1]
    return [notes[lo:hi] for lo, hi in zip(starts, ends)]
def chunk_first_embedding(chunk_size, embedding):
    """Cut ``embedding`` into consecutive slices of ``chunk_size`` tokens.

    A sequence shorter than ``chunk_size`` is returned whole as the only
    chunk. Otherwise full-size slices are taken from the front, and the
    trailing slice is kept only if it holds more than ``chunk_size / 2``
    tokens.

    Returns:
        list of slices of ``embedding``.
    """
    total = len(embedding)
    if total < chunk_size:
        return [embedding]
    chunks = []
    for offset in range(0, total, chunk_size):
        if total - offset <= chunk_size / 2:
            break
        chunks.append(embedding[offset:offset + chunk_size])
    return chunks
def encode_midi_in_chunks(file_path, n_aug, n_noise):
    """Encode a MIDI file as chunks of event ids plus augmented chunks.

    The whole file is encoded once with convert_notes and cut with
    chunk_first_embedding(CHUNK_SIZE, ...). The notes are then grouped so
    that each group holds as many notes as its chunk has ids < 128 (note_on
    ids). Each group is shifted to start at 0 and re-encoded (truncated to
    CHUNK_SIZE). For each of ``n_aug`` rounds, every group is altered with
    alter_notes (using that group's minimum pitch), shifted and re-encoded.

    shift_notes modifies the note objects in ``chunked_notes`` in place, so
    the augmentation rounds start from already-shifted notes.

    Args:
        file_path: path of the MIDI file (loaded with pretty_midi).
        n_aug: number of augmentation rounds.
        n_noise: ignored (overwritten with 0).

    Returns:
        (encoded_chunks, encoded_augmentations, []) where
        encoded_augmentations holds n_aug * len(encoded_chunks) entries,
        round by round.
    """
    n_noise = 0
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)
    # preprocess midi (sustain pedal)
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    # convert notes to ints
    main_embedding = convert_notes(notes)
    # split the sequence of events in chunks (ids must lie in [0, MAX_EMBEDDING))
    if np.max(main_embedding) < MAX_EMBEDDING and np.min(main_embedding) >= 0:
        encoded_chunks = chunk_first_embedding(CHUNK_SIZE, main_embedding)
    else:
        assert False
    # ids < 128 are note_on events: one per note
    n_notes_per_chunk = [np.argwhere(np.array(ec) < 128).flatten().size for ec in encoded_chunks]
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)
    # reencode chunks by shifting notes (in place) to start at 0
    encoded_chunks = []
    for note_group in chunked_notes:
        note_group = shift_notes(note_group)
        embedding_main = convert_notes(note_group)[:CHUNK_SIZE]
        encoded_chunks.append(embedding_main)
    min_pitches = [np.min([n.pitch for n in cn]) for cn in chunked_notes]
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group, min_pitch in zip(chunked_notes, min_pitches):
            a_notes = alter_notes(note_group, aug_shift, min_pitch)
            a_notes = shift_notes(a_notes)
            assert len(a_notes) == len(note_group)
            embedding_group = convert_notes(a_notes)[:CHUNK_SIZE]  # encode notes
            chunked_embedding_aug.append(embedding_group)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []
def encode_miditok_in_chunks(file_path, n_aug, n_noise):
    """Encode a MIDI file as CPWord token chunks plus augmented chunks.

    NOTE(review): the module-level construction of ``tokenizer_cp`` is
    commented out, so this function raises NameError unless tokenizer_cp is
    defined elsewhere.

    The file is tokenized once and cut with chunk_first_embedding. The notes
    are grouped so that each group holds as many notes as its token chunk has
    tokens whose first element maps to 'Family_Note' (presumably one per
    note -- confirm against miditok CPWord). Each group is written into
    ``mid.instruments[0]``, shifted with shift_mid and re-tokenized
    (truncated to CHUNK_SIZE). For each of ``n_aug`` rounds, every group is
    altered with alter_notes_exact_tick and re-tokenized the same way.

    shift_mid modifies the note objects of each group in place, so the
    augmentation rounds start from already-shifted notes; ``mid`` is mutated.

    Args:
        file_path: path of the MIDI file (loaded with miditoolkit).
        n_aug: number of augmentation rounds.
        n_noise: ignored (overwritten with 0).

    Returns:
        (encoded_chunks, encoded_augmentations, []) where
        encoded_augmentations holds n_aug * len(encoded_chunks) entries.
    """
    n_noise = 0
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)
    # Sustain-pedal preprocessing for every instrument
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)
    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    # tokenize the whole file, then cut the token sequence into chunks
    encoded_main = tokenizer_cp.midi_to_tokens(mid)[0]
    encoded_chunks = chunk_first_embedding(CHUNK_SIZE, encoded_main)
    # number of note tokens per chunk decides how many notes each group gets
    n_notes_per_chunk = [len([tokenizer_cp.vocab.token_to_event[e[0]] for e in enc_chunk if tokenizer_cp.vocab.token_to_event[e[0]] == 'Family_Note'])
                         for enc_chunk in encoded_chunks]
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)
    # reencode chunks by shifting notes
    encoded_chunks = []
    for note_group in chunked_notes:
        mid.instruments[0].notes = note_group
        mid = shift_mid(mid)  # shift midi (mutates note_group's notes)
        assert midi_valid(mid)
        embedding_main = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE]  # tokenize midi
        encoded_chunks.append(embedding_main)
    # file-wide lowest pitch bounds downward pitch shifts
    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group in chunked_notes:
            a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch)
            assert len(a_notes) == len(note_group)
            mid.instruments[0].notes = a_notes
            # shift midi
            mid = shift_mid(mid)
            assert midi_valid(mid)
            # tokenize midi
            embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE]  # encode notes
            chunked_embedding_aug.append(embedding_aug)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []
def encode_midi_chunks_structured(file_path, n_aug, n_noise):
    """Encode a MIDI file as fixed-size note chunks with tokenizer_structured.

    Notes are grouped into chunks of ``CHUNK_SIZE // 4`` notes (presumably an
    estimate of tokens per note -- confirm), at most 50 chunks; a trailing
    partial chunk is kept only if it has more than half that many notes.
    Each group is written into ``mid.instruments[0]``, shifted with shift_mid
    and tokenized. For each of ``n_aug`` rounds, every group is altered with
    alter_notes_exact_tick and tokenized the same way.

    shift_mid modifies the note objects of each group in place, so the
    augmentation rounds start from already-shifted notes; ``mid`` is mutated.

    Args:
        file_path: path of the MIDI file (loaded with miditoolkit).
        n_aug: number of augmentation rounds.
        n_noise: ignored (overwritten with 0).

    Returns:
        (encoded_chunks, encoded_augmentations, []) where
        encoded_augmentations holds n_aug * len(encoded_chunks) entries.

    Raises:
        AssertionError: if the note count check fails or the file has fewer
            than 2 full chunks of notes.
    """
    n_noise = 0
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)
    # Sustain-pedal preprocessing for every instrument
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number is the number of sustain control. If you want to know about the number type of control,
        # see https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)
    nb_notes = CHUNK_SIZE // 4
    notes = notes[:50 * nb_notes]  # limit to 50 chunks to speed up
    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    assert (len(notes) // nb_notes) > 1  # assert at least 2 full chunks
    n_notes_per_chunk = [nb_notes for _ in range(len(notes) // nb_notes)]
    # keep the leftover notes as a last chunk only if it is more than half full
    if len(notes) % nb_notes > nb_notes / 2:
        n_notes_per_chunk.append(len(notes) % nb_notes)
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)
    # reencode chunks by shifting notes
    encoded_chunks = []
    for note_group in chunked_notes:
        mid.instruments[0].notes = note_group
        mid = shift_mid(mid)  # shift midi (mutates note_group's notes)
        assert midi_valid(mid)
        embedding_main = tokenizer_structured.midi_to_tokens(mid)[0]  # tokenize midi
        encoded_chunks.append(embedding_main)
    # file-wide lowest pitch bounds downward pitch shifts
    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group in chunked_notes:
            a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch)
            assert len(a_notes) == len(note_group)
            mid.instruments[0].notes = a_notes
            # shift midi
            mid = shift_mid(mid)
            assert midi_valid(mid)
            # tokenize midi
            embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes
            chunked_embedding_aug.append(embedding_aug)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []
def shift_mid(mid):
    """Shift the first instrument's notes so the first note starts at 0.

    The offset is the ``start`` of ``mid.instruments[0].notes[0]``; nothing
    happens unless it is positive. Note objects are modified in place and
    the same ``mid`` is returned. Tempo, time-signature and key-signature
    events are not shifted.
    """
    notes = mid.instruments[0].notes
    offset = notes[0].start
    if not offset > 0:
        return mid
    for note in notes:
        note.start -= offset
        note.end -= offset
    return mid
def shift_notes(notes):
    """Shift all notes by the first note's start so that it starts at 0.

    Note objects are modified in place and the same list is returned. The
    shift is applied whatever the sign of the first start.
    """
    offset = notes[0].start
    for note in notes:
        note.start, note.end = note.start - offset, note.end - offset
    return notes
def convert_notes(notes):
    """Encode notes as a flat list of integer event ids.

    Each note is split into note_on/note_off halves, the halves are ordered
    by time, and for every half the encoder emits the time_shift events for
    the gap since the previous half, an optional velocity event (note_on
    halves only) and the note event itself.

    NOTE(review): the velocity carried to the next step is the raw
    ``snote.velocity`` (None after a note_off), while _snote2events compares
    it with ``velocity // 4``. The "skip unchanged velocity" check therefore
    matches only by coincidence, and a velocity token precedes almost every
    note_on. Left as-is because fixing it changes the produced token stream.

    Raises:
        AssertionError: if any id falls outside [0, MAX_EMBEDDING).
        ValueError: from np.max when ``notes`` is empty.
    """
    halves = sorted(_divide_note(notes), key=lambda half: half.time)
    events = []
    prev_time, prev_vel = 0, 0
    for half in halves:
        events.extend(_make_time_sift_events(prev_time=prev_time, post_time=half.time))
        events.extend(_snote2events(snote=half, prev_vel=prev_vel))
        prev_time, prev_vel = half.time, half.velocity
    event_ids = [event.to_int() for event in events]
    in_range = np.max(event_ids) < MAX_EMBEDDING and np.min(event_ids) >= 0
    if not in_range:
        print('weird')
        assert False
    return event_ids
def decode_midi_structured(encoding, file_path=None):
    """Turn one token sequence back into MIDI with tokenizer_structured.

    Args:
        encoding: token sequence of a single track (wrapped in a list for
            tokens_to_midi).
        file_path: if truthy, the result is also dumped to this path.

    Returns:
        the object returned by tokenizer_structured.tokens_to_midi.
    """
    midi = tokenizer_structured.tokens_to_midi([encoding])
    if not file_path:
        return midi
    midi.dump(file_path)
    return midi
def decode_midi_cp(encoding, file_path=None):
    """Turn one CPWord token sequence back into MIDI with tokenizer_cp.

    NOTE(review): the module-level construction of ``tokenizer_cp`` is
    commented out, so this raises NameError unless tokenizer_cp is defined
    elsewhere.

    Args:
        encoding: token sequence of a single track (wrapped in a list for
            tokens_to_midi).
        file_path: if truthy, the result is also dumped to this path.

    Returns:
        the object returned by tokenizer_cp.tokens_to_midi.
    """
    midi = tokenizer_cp.tokens_to_midi([encoding])
    if not file_path:
        return midi
    midi.dump(file_path)
    return midi
def decode_midi(idx_array, file_path=None):
    """Rebuild a pretty_midi.PrettyMIDI from a sequence of event ids.

    Ids are decoded with Event.from_int, replayed into timed note_on/off
    halves and merged into notes. The notes are sorted by start and placed
    on a single instrument (program 1, not a drum track).

    Args:
        idx_array: iterable of integer event ids.
        file_path: if not None, the MIDI is also written to this path.

    Returns:
        the PrettyMIDI object.
    """
    events = [Event.from_int(idx) for idx in idx_array]
    notes = _merge_note(_event_seq2snote_seq(events))
    notes.sort(key=lambda note: note.start)
    midi = pretty_midi.PrettyMIDI()
    # To use another instrument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
    instrument = pretty_midi.Instrument(1, False, "Developed By Yang-Kichang")
    instrument.notes = notes
    midi.instruments.append(instrument)
    if file_path is not None:
        midi.write(file_path)
    return midi
if __name__ == '__main__':
    # Round-trip smoke test: encode a MIDI file, print the event ids, decode
    # them back to bin/test.mid, then dump the original file's contents.
    # encode_midi returns (main_encoding, augmentations, noise); request no
    # augmentation or noise and keep only the main encoding.
    encoded, _, _ = encode_midi('bin/ADIG04.mid', nb_aug=0, nb_noise=0)
    print(encoded)
    decoded = decode_midi(encoded, file_path='bin/test.mid')
    ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
    print(ins)
    print(ins.instruments[0])
    for inst in ins.instruments:
        print(inst.control_changes)
        print(inst.notes)