| import os,csv
|
|
|
| from argparse import ArgumentParser
|
| from glob import glob
|
|
|
| import numpy as np
|
|
|
| from tqdm import tqdm
|
|
|
| from anticipation import ops
|
| from anticipation.visuals import visualize
|
| from anticipation.tokenize import extract_instruments
|
| from anticipation.convert import midi_to_events, events_to_midi
|
| from anticipation.config import TIME_RESOLUTION
|
| from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
|
|
|
def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly draw a file and a clip from it that passes every selection filter.

    Repeatedly samples a file index and then a clip start time from the global
    NumPy RNG until a candidate passes all of the filters below, and returns it.

    Filters applied to each candidate:
      * the file must be readable by ``midi_to_events``;
      * the track must be at least ``clip_length`` seconds long;
      * the clip must contain between 4 and 10 instruments, not counting
        instrument 128;
      * a melody instrument must exist. It is the instrument with the highest
        mean pitch among those not in {0, 9, 112..128};
      * the melody must have at least 20 notes, must start within the first
        ``prompt_length`` seconds, and must not end before ``clip_length - 2``
        seconds.

    Args:
        filenames: sequence of MIDI file paths to sample from.
        prompt_length: length of the prompt window, in seconds.
        clip_length: length of the clip to extract, in seconds.
        verbose: if True, print each sampled index and why a candidate was rejected.

    Returns:
        ``(basename, clip, melody)``: the base name of the chosen file, the clip's
        event sequence shifted back by its start time, and the melody instrument.

    Raises:
        ValueError: if ``filenames`` is empty.

    Note:
        This loops forever if no file in ``filenames`` ever yields an acceptable clip.
    """
    if len(filenames) == 0:
        raise ValueError('select_sample: no MIDI files to sample from')

    # Instruments that are never considered as the melody part.
    not_melodic = set([0, 9] + list(range(112, 129)))

    while True:
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception as err:
            # Unreadable/malformed MIDI: best-effort skip, draw another file.
            if verbose:
                print(f' rejected: could not parse file ({err})')
            continue

        track_length = ops.max_time(events)
        max_time = track_length - clip_length  # latest admissible start time
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            continue

        # Uniform start in [0, max_time) so the whole clip fits in the track.
        start_time = max_time * np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time + clip_length, clip_duration=True)
        # start_time is in seconds; scale by TIME_RESOLUTION to the event time units.
        clip = ops.translate(clip, -int(TIME_RESOLUTION * start_time))

        # Per-instrument note counts (compared against a note threshold below).
        instruments = ops.get_instruments(clip)

        n_parts = sum(1 for instr in instruments if instr != 128)
        if n_parts < 4 or n_parts > 10:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {n_parts}')
            continue

        # Events are (time, duration, note) triples; the note token packs
        # instrument and pitch as instr * 128 + pitch after removing NOTE_OFFSET.
        pitches = {instr: [] for instr in instruments}
        for note in clip[2::3]:
            instr, pitch = divmod(note - NOTE_OFFSET, 2**7)
            pitches[instr].append(pitch)

        # The melody is the eligible instrument with the highest mean pitch.
        melody = None
        high = 0
        for instr in instruments:
            if instr in not_melodic:
                continue
            avg = np.mean(pitches[instr])
            if avg > high:
                melody = instr
                high = avg

        if melody is None:
            # Previously an assert: crashed the run (or passed silently under -O).
            if verbose:
                print(' rejected: no eligible melody instrument')
            continue

        if instruments[melody] < 20:
            if verbose:
                print(' rejected: too few melodic notes')
            continue

        if ops.min_time(clip, seconds=True, instr=melody) > prompt_length:
            if verbose:
                print(' rejected: prompt does not contain the melody')
            continue

        if ops.max_time(clip, seconds=True, instr=melody) < (clip_length - 2):
            if verbose:
                print(' rejected: melody ends before the end of the clip')
            continue

        return os.path.basename(filenames[idx]), clip, melody
|
|
|
|
|
def main(args):
    """Sample ``args.count`` clips and write ground-truth and conditional MIDI files.

    For each sampled clip ``i`` this writes:
      * ``<output>/groundtruth/{i}-clip.mid``: the full sampled clip;
      * ``<output>/{i}-conditional.mid``: the conditional input. It is the
        first ``prompt_length`` seconds of ``events`` combined with
        ``controls``, clipped to ``clip_length``. Both come from
        ``extract_instruments(clip, [melody])``; presumably the melody part
        becomes the controls — confirm;
      * matching ``.png`` visualizations when ``args.visualize`` is set;
      * one row in ``<output>/index.csv``.

    Args:
        args: parsed namespace with ``dir``, ``output``, ``seed``, ``count``,
            ``prompt_length``, ``clip_length`` and ``visualize``.

    Raises:
        SystemExit: if no ``.mid``/``.midi`` files are found under ``args.dir``.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    filenames = sorted(
        glob(os.path.join(args.dir, '**', '*.mid'), recursive=True)
        + glob(os.path.join(args.dir, '**', '*.midi'), recursive=True)
    )
    if not filenames:
        # Fail early with a clear message instead of inside select_sample.
        raise SystemExit(f'No .mid/.midi files found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    groundtruth_dir = os.path.join(args.output, 'groundtruth')
    # Creates args.output as well (it is the parent); no error if either exists.
    os.makedirs(groundtruth_dir, exist_ok=True)

    with open(os.path.join(args.output, 'index.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'conditional', 'parts', 'melody'])

        for i in tqdm(range(args.count)):
            filename, clip, melody = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip)
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts), melody])

            # Ground truth: the full sampled clip.
            events_to_midi(clip).save(os.path.join(groundtruth_dir, f'{i}-clip.mid'))
            if args.visualize:
                visualize(clip, os.path.join(groundtruth_dir, f'{i}-clip.png'))

            # Conditional input: prompt-length prefix of `events` plus all `controls`.
            events, controls = extract_instruments(clip, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)
            conditional_events = ops.clip(ops.combine(prompt, controls), 0, args.clip_length)
            events_to_midi(conditional_events).save(
                os.path.join(args.output, f'{i}-conditional.mid'))
            if args.visualize:
                visualize(conditional_events, os.path.join(args.output, f'{i}-conditional.png'))
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse arguments, validate them, then run main().
    parser = ArgumentParser(description='select prompts for accompaniment completion human eval')
    parser.add_argument('dir', help='directory containing MIDI files to sample')
    parser.add_argument('-o', '--output', type=str, default='output',
                        help='output directory')
    parser.add_argument('-s', '--seed', type=int, default=0,
                        help='random seed for sampling')
    parser.add_argument('-c', '--count', type=int, default=10,
                        help='number of clips to sample')
    parser.add_argument('-p', '--prompt_length', type=int, default=5,
                        help='length of the prompt (in seconds)')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
                        help='length of the full clip (in seconds)')
    parser.add_argument('-v', '--visualize', action='store_true',
                        help='plot visualizations')
    cli_args = parser.parse_args()

    # Lengths that make every candidate clip fail the filters would make
    # select_sample resample forever; reject them up front.
    if cli_args.prompt_length <= 0 or cli_args.clip_length <= 0:
        parser.error('--prompt_length and --clip_length must be positive')
    if cli_args.prompt_length > cli_args.clip_length:
        parser.error('--prompt_length must not exceed --clip_length')

    main(cli_args)
|
|
|