Created April 20, 2024 02:54
-
-
Save RoadrunnerWMC/c047f7841dc5e50f35013b5d5e42c532 to your computer and use it in GitHub Desktop.
A Python script to find the maximum number of notes that ever play simultaneously in a MIDI file, across all tracks.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import argparse
import itertools
from pathlib import Path

import mido
# One event in the merged timeline: (tick, 'start' or 'end', unique note id).
_NoteEvent = tuple[int, str, int]
# A full timeline of note events, as returned by squash_track_notes().
NoteList = list[_NoteEvent]
def squash_track_notes(midi: mido.MidiFile) -> NoteList:
    """
    Merge the notes of every track into a single timeline.

    Returns a list of (tick, 'start'/'end', unique_id) triples, one 'start'
    and one 'end' per note, sorted by absolute tick. Absolute ticks are
    accumulated per track from each message's delta time. Within a track a
    note is identified by its (channel, note number), and a note_on with
    velocity 0 is treated as a note_off. Unique ids are assigned from 1
    upward, in the order that notes end.

    Malformed input is reported with a warning on stdout instead of raising:
    - A note_on for a note that is already sounding replaces the earlier start.
    - A note_off with no matching note_on is skipped.
    - Notes still sounding when their track ends are dropped.
    """
    id_counter = itertools.count(1)
    all_notes = []
    for track_i, track in enumerate(midi.tracks):
        # (channel, note) -> tick at which that note started
        active_notes = {}
        time = 0
        for msg in track:
            # msg.time is a delta from the previous message in this track
            time += msg.time
            if msg.type == 'note_on' and msg.velocity > 0:
                local_id = (msg.channel, msg.note)
                if local_id in active_notes:
                    print(f'WARNING: duplicate note ({msg}) at {time} on track {track_i + 1}')
                active_notes[local_id] = time
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                local_id = (msg.channel, msg.note)
                start_time = active_notes.pop(local_id, None)
                if start_time is None:
                    print(f'WARNING: start of note ({msg}) at {time} on track {track_i + 1} was never seen')
                    # No matching note_on, so there is nothing to pair this end with.
                    continue
                global_id = next(id_counter)
                all_notes.append((start_time, 'start', global_id))
                all_notes.append((time, 'end', global_id))
        if active_notes:
            print(f'WARNING: some notes in track {track_i + 1} never ended')
    # The sort is stable, so a zero-length note keeps its 'start' ahead of its 'end'.
    all_notes.sort(key=lambda elem: elem[0])
    return all_notes
def find_times_with_most_active_notes(all_notes: NoteList) -> tuple[list[int], int]:
    """
    Find the largest number of notes sounding at once, and every tick where it occurs.

    all_notes is a list of (tick, 'start'/'end', id) events sorted by tick,
    such as squash_track_notes() returns. All events that share a tick are
    applied together before counting. As a result, the answer does not depend
    on how events with equal ticks happen to be ordered:
    - A note that ends at tick T does not overlap a note that starts at tick T.
    - A note that both starts and ends at T still counts as sounding at T.

    Returns (ticks, count). count is the maximum number of active notes, and
    ticks lists where that maximum occurs. For sorted input, ticks is in
    ascending order and lists each tick once. An empty input gives ([], 0).
    """
    active_notes = set()
    worst_times = []
    worst_amount = 0
    for time, events in itertools.groupby(all_notes, key=lambda elem: elem[0]):
        starting = []
        # Ends of notes that were not already active before this tick, i.e.
        # notes that start and end at this same tick. They are released only
        # after counting, so zero-length notes still count once.
        deferred_ends = []
        for _, type_, id_ in events:
            if type_ == 'start':
                starting.append(id_)
            elif id_ in active_notes:
                active_notes.remove(id_)
            else:
                deferred_ends.append(id_)
        active_notes.update(starting)
        amount = len(active_notes)
        if amount > worst_amount:
            worst_amount = amount
            worst_times = [time]
        elif amount == worst_amount and amount > 0:
            worst_times.append(time)
        active_notes.difference_update(deferred_ends)
    if active_notes:
        print('WARNING: some notes in the squashed track never ended')
    return worst_times, worst_amount
def main(argv: list[str] | None = None) -> None:
    """
    Command-line entry point: report when the most notes play at once in a MIDI file.

    argv: the argument list to parse. None means argparse reads sys.argv[1:].
    """
    parser = argparse.ArgumentParser(
        description='MIDI Simultaneous Note Counter')
    parser.add_argument('midi_file', type=Path,
        help='input file to check')
    parser.add_argument('--beats-per-measure', type=float, default=4,
        help='number of beats per measure (default: 4, meaning 4/4 time)')
    parser.add_argument('--extend-notes-by', type=float, default=16,
        help="treat all notes as if they're this much longer, to account for fade-out duration (default: 16, meaning the duration of a sixteenth note)")
    args = parser.parse_args(argv)
    if args.beats_per_measure <= 0:
        # Measure numbers are computed by dividing by this value.
        parser.error('--beats-per-measure must be positive')

    print(f'Checking {args.midi_file}...')
    midi = mido.MidiFile(args.midi_file)

    quarter_note = midi.ticks_per_beat
    # --extend-notes-by is a note value (16 = sixteenth note). A note of value
    # N lasts 4/N quarter notes. Non-positive values disable the extension.
    if args.extend_notes_by <= 0:
        note_time_extension = 0
    else:
        note_time_extension = round(quarter_note / (args.extend_notes_by / 4))

    # First, conceptually squash everything into a single track
    all_notes = squash_track_notes(midi)

    # Lengthen all notes slightly, to allow them time to fade out
    if note_time_extension > 0:
        all_notes = [
            (time + note_time_extension if type_ == 'end' else time, type_, id_)
            for time, type_, id_ in all_notes
        ]
        all_notes.sort(key=lambda elem: elem[0])

    # Then find the times with the most notes active
    worst_times, worst_amount = find_times_with_most_active_notes(all_notes)

    # And finally, print them out
    print(f'{worst_amount} notes are playing simultaneously at:')
    for time in worst_times:
        beat = time / midi.ticks_per_beat
        # Use the configured measure length (not a hard-coded 4) so the measure
        # number agrees with beat_in_measure below.
        measure = int(beat // args.beats_per_measure)
        beat_in_measure = beat % args.beats_per_measure
        print(f'- tick {time} (beat {beat + 1:0.3f}, aka measure {measure + 1} beat {beat_in_measure + 1:0.3f})')
    print(f'(Ticks are zero-indexed. Beats and measures are one-indexed, and assume a {args.beats_per_measure:g}/4 time signature.)')
# Run the command-line interface only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.