| import os, csv, time
|
|
|
| from argparse import ArgumentParser
|
|
|
| import numpy as np
|
|
|
| import torch
|
|
|
| from transformers import AutoModelForCausalLM
|
|
|
| from anticipation import ops
|
| from anticipation.visuals import visualize
|
| from anticipation.convert import midi_to_interarrival, interarrival_to_midi
|
| from anticipation.convert import midi_to_events, events_to_midi
|
| from anticipation.vocab import MIDI_SEPARATOR,MIDI_START_OFFSET,MIDI_END_OFFSET
|
|
|
|
|
def main(args):
    """Sample completions for every prompt listed in ``{args.dir}/index.csv``.

    Loads a causal language model from ``args.model`` onto the GPU, then for
    each row of the index converts the prompt MIDI file to interarrival
    tokens, trims the prompt, and samples ``args.multiplicity`` completions.
    Each completion is converted back to MIDI, clipped to
    ``args.clip_length``, and saved as ``{idx}-clip-v{j}.mid`` (plus a
    ``.png`` when ``args.visualize`` is set) under
    ``{args.dir}/{args.output}``.

    Args:
        args: parsed command-line namespace. This function reads ``seed``,
            ``model``, ``dir``, ``output``, ``multiplicity``,
            ``clip_length`` and ``visualize``.

    Raises:
        ValueError: if the index header has no ``prompt`` or ``idx`` column.
    """
    np.random.seed(args.seed)
    # Sampling inside model.generate draws from torch's RNG, so seeding
    # NumPy alone does not make --seed reproducible.
    torch.manual_seed(args.seed)

    print(f'Prompting using model checkpoint: {args.model}')
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
    print(f'Loaded model ({time.time()-t0} seconds)')

    out_dir = f'{args.dir}/{args.output}'
    print(f'Writing outputs to {out_dir}')
    os.makedirs(out_dir, exist_ok=True)

    print(f'Prompting with tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Column positions are fixed by the header; resolve them once
        # instead of on every row.
        prompt_col = header.index('prompt')
        idx_col = header.index('idx')

        for row in reader:
            prompt_midi = row[prompt_col]
            idx = int(row[idx_col])

            prompt = midi_to_interarrival(os.path.join(args.dir, prompt_midi))

            # Cut the prompt right after the last token whose value lies in
            # [MIDI_START_OFFSET, MIDI_START_OFFSET + MIDI_END_OFFSET).
            # If no token matches, only the first token is kept.
            # NOTE(review): the upper bound adds the two offsets; confirm
            # against anticipation.vocab that this is the intended range.
            max_idx = 0
            for i, token in enumerate(prompt):
                if MIDI_START_OFFSET <= token < MIDI_START_OFFSET + MIDI_END_OFFSET:
                    max_idx = i

            prompt = prompt[:max_idx+1]

            # The prompt is identical for every sample of this row, so the
            # input tensor is built and moved to the GPU only once.
            input_ids = torch.tensor([prompt]).cuda()

            for j in range(args.multiplicity):
                t0 = time.time()

                output = model.generate(
                    input_ids,
                    do_sample=True,
                    max_length=1024,
                    top_p=0.95,
                    pad_token_id=MIDI_SEPARATOR,
                )
                output = output[0].cpu().tolist()

                # Round-trip through MIDI to obtain events, which is the
                # representation ops.clip is called with.
                mid = interarrival_to_midi(output)
                events = midi_to_events(mid)
                output = ops.clip(events, 0, args.clip_length)
                mid = events_to_midi(output)
                mid.save(f'{out_dir}/{idx}-clip-v{j}.mid')
                if args.visualize:
                    visualize(output, f'{out_dir}/{idx}-clip-v{j}.png')

                print(f'Generated completion. Sampling time: {time.time()-t0} seconds')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the command-line interface, parse the
    # arguments, and hand the resulting namespace to main().
    cli = ArgumentParser(
        description='generate prompted completions with an interarrival-time model')

    # Required positional arguments, registered in order.
    positionals = (
        ('dir', 'directory containing an index of MIDI files'),
        ('model', 'directory containing an interarrival model checkpoint'),
    )
    for dest, text in positionals:
        cli.add_argument(dest, help=text)

    # Optional flags as (short flag, long flag, add_argument keywords).
    # NOTE(review): --count is accepted here but main() never reads it.
    options = (
        ('-o', '--output', dict(
            type=str, default='model',
            help='model description (the name of the output subdirectory)')),
        ('-s', '--seed', dict(
            type=int, default=0,
            help='random seed')),
        ('-c', '--count', dict(
            type=int, default=10,
            help='number of clips to sample')),
        ('-m', '--multiplicity', dict(
            type=int, default=1,
            help='number of generations per clip')),
        ('-l', '--clip_length', dict(
            type=int, default=20,
            help='length of the full clip (in seconds)')),
        ('-v', '--visualize', dict(
            action='store_true',
            help='plot visualizations')),
    )
    for short_flag, long_flag, settings in options:
        cli.add_argument(short_flag, long_flag, **settings)

    main(cli.parse_args())
|
|
|