Source code for jdxi_editor.devel.analyze_drum_track

#!/usr/bin/env python3
"""
Analyze a MIDI file to identify the drum track.

This script uses multiple heuristics to determine which track is the drum track:
1. MIDI Channel 9 (standard GM drum channel)
2. Track name containing drum-related keywords
3. Note patterns (drum range, note density, polyphony)
4. Lack of pitch bend/control changes
"""

import os
import sys
from collections import defaultdict
from typing import List, Optional, Tuple

from picomidi.message.type import MidoMessageType

try:
    import mido
    from mido import MidiFile
except ImportError:
    print("ERROR: mido library not found. Install with: pip install mido")
    sys.exit(1)


# Inclusive bounds of the note numbers treated as "drum notes"
# (the standard General MIDI drum note range, 35-81).
DRUM_NOTE_MIN = 35
DRUM_NOTE_MAX = 81

# Lower-case substrings; a track whose name contains any of them is
# treated as drum-related.
DRUM_KEYWORDS = [
    # generic names
    "drum",
    "percussion",
    "perc",
    "kit",
    "beat",
    "rhythm",
    # individual instruments
    "snare",
    "kick",
    "hihat",
    "hi-hat",
    "cymbal",
    "crash",
]
def analyze_track_for_drums(track: mido.MidiTrack, track_index: int) -> dict:
    """
    Analyze a MIDI track to determine if it's likely a drum track.

    Args:
        track: The track to inspect. It is iterated twice: once to find the
            track name, once to collect channel/note statistics.
        track_index: Index of the track; copied verbatim into the result.

    Returns:
        A dictionary with the collected statistics ("track_name", "channels",
        "note_count", "drum_note_count", "note_ons", "note_offs",
        "avg_note_duration", "max_simultaneous", "has_pitch_bend",
        "has_control_change", "program_changes", ...) and a "score" where a
        higher value means "more drum-like". The "simultaneous_notes" key is
        initialised to 0 and never updated here.
    """
    # Function-scope import: only the note on/off matching below needs it.
    from bisect import bisect_right

    analysis = {
        "track_index": track_index,
        "track_name": None,
        "channels": set(),
        "note_count": 0,
        "drum_note_count": 0,
        "note_ons": [],
        "note_offs": [],
        "avg_note_duration": 0.0,
        "simultaneous_notes": 0,
        "max_simultaneous": 0,
        "has_pitch_bend": False,
        "has_control_change": False,
        "program_changes": [],
        "score": 0.0,
    }

    # Extract track name from the first track_name meta message
    for msg in track:
        if msg.is_meta and msg.type == "track_name":
            analysis["track_name"] = msg.name
            break

    # Resolve the message-type strings once. Every comparison uses `.value`
    # so that the pitch-wheel check behaves like the others (the original
    # compared against the enum member itself for PITCH_WHEEL only).
    note_on_type = MidoMessageType.NOTE_ON.value
    note_off_type = MidoMessageType.NOTE_OFF.value
    pitch_wheel_type = MidoMessageType.PITCH_WHEEL.value
    control_change_type = MidoMessageType.CONTROL_CHANGE.value
    program_change_type = MidoMessageType.PROGRAM_CHANGE.value

    # Number of note-ons that start at each absolute tick. This is a
    # simplified polyphony measure: it counts notes struck on the same tick,
    # not notes that are still sounding.
    notes_started_at = defaultdict(int)
    absolute_time = 0

    for msg in track:
        absolute_time += msg.time  # msg.time is a delta; accumulate it

        if msg.is_meta:
            continue

        if hasattr(msg, "channel"):
            analysis["channels"].add(msg.channel)

        if msg.type == note_on_type and msg.velocity > 0:
            analysis["note_count"] += 1
            analysis["note_ons"].append((absolute_time, msg.note, msg.channel))

            # Check if note is in drum range
            if DRUM_NOTE_MIN <= msg.note <= DRUM_NOTE_MAX:
                analysis["drum_note_count"] += 1

            notes_started_at[absolute_time] += 1
            analysis["max_simultaneous"] = max(
                analysis["max_simultaneous"], notes_started_at[absolute_time]
            )
        elif msg.type == note_off_type or (
            msg.type == note_on_type and msg.velocity == 0
        ):
            # A note-on with velocity 0 is treated as a note-off.
            analysis["note_offs"].append((absolute_time, msg.note))
        elif msg.type == pitch_wheel_type:
            analysis["has_pitch_bend"] = True
        elif msg.type == control_change_type:
            analysis["has_control_change"] = True
        elif msg.type == program_change_type:
            analysis["program_changes"].append(msg.program)

    # Calculate average note duration (simplified): each note-on is paired
    # with the first later note-off of the same note number; channels are
    # ignored and a note-off may be shared by several note-ons.
    # Off times are grouped per note so each lookup is a binary search
    # instead of a scan over every note-off. The per-note lists are in
    # chronological order because they are filled in track order
    # (assumes non-negative delta times).
    off_times_by_note = defaultdict(list)
    for off_time, off_note in analysis["note_offs"]:
        off_times_by_note[off_note].append(off_time)

    durations = []
    for on_time, note, _ in analysis["note_ons"]:
        off_times = off_times_by_note.get(note)
        if not off_times:
            continue
        # First note-off strictly after the note-on.
        position = bisect_right(off_times, on_time)
        if position < len(off_times):
            durations.append(off_times[position] - on_time)

    if durations:
        analysis["avg_note_duration"] = sum(durations) / len(durations)

    # Calculate score
    score = 0.0

    # Channel 9 (drum channel) = +50 points
    if 9 in analysis["channels"]:
        score += 50.0

    # Track name contains drum keywords = +30 points
    if analysis["track_name"]:
        name_lower = analysis["track_name"].lower()
        if any(keyword in name_lower for keyword in DRUM_KEYWORDS):
            score += 30.0

    # High percentage of drum notes = +20 points (+10 for a majority)
    if analysis["note_count"] > 0:
        drum_percentage = (
            analysis["drum_note_count"] / analysis["note_count"]
        ) * 100
        if drum_percentage > 80:
            score += 20.0
        elif drum_percentage > 50:
            score += 10.0

    # High note density = +10 points
    if analysis["note_count"] > 100:
        score += 10.0

    # High polyphony (many simultaneous notes) = +10 points
    if analysis["max_simultaneous"] > 5:
        score += 10.0

    # Short average note duration (typical for drums) = +5 points
    if 0 < analysis["avg_note_duration"] < 500:
        score += 5.0

    # No pitch bend (drums don't use it) = +5 points.
    # Only awarded to tracks that actually contain notes: absence of pitch
    # bend on a note-less track (e.g. one holding only meta events) is not
    # evidence of drums and must not lift its score above 0.
    if analysis["note_count"] > 0 and not analysis["has_pitch_bend"]:
        score += 5.0

    analysis["score"] = score
    return analysis
def find_drum_track(midi_file_path: str) -> Optional[Tuple[int, dict]]:
    """
    Pick the track of a MIDI file that looks most like a drum track.

    Every track is scored with analyze_track_for_drums; a per-track report
    and the final verdict are printed to stdout.

    Returns:
        (track_index, analysis_dict) for the best-scoring track, or None if
        the file is missing, cannot be loaded, contains no tracks, or the
        best score is not above 0.
    """
    if not os.path.exists(midi_file_path):
        print(f"ERROR: File not found: {midi_file_path}")
        return None

    try:
        midi = MidiFile(midi_file_path)
    except Exception as exc:
        print(f"ERROR: Could not load MIDI file: {exc}")
        return None

    banner = "=" * 60

    def display_name(entry):
        """Track name if one was found, otherwise 'Track <1-based index>'."""
        return entry["track_name"] or f"Track {entry['track_index'] + 1}"

    # --- File header -----------------------------------------------------
    print(f"\n{banner}")
    print(f"Analyzing MIDI file: {midi_file_path}")
    print(f"{banner}")
    print(f"Ticks per beat: {midi.ticks_per_beat}")
    print(f"Number of tracks: {len(midi.tracks)}")
    print(f"Type: {midi.type}")
    print()

    # --- Score every track -----------------------------------------------
    reports = [
        analyze_track_for_drums(track, index)
        for index, track in enumerate(midi.tracks)
    ]

    # --- Per-track report ------------------------------------------------
    print("Track Analysis:")
    print("-" * 60)
    for report in reports:
        number = report["track_index"] + 1
        total_notes = report["note_count"]
        drum_notes = report["drum_note_count"]

        # Channels are shown 1-based.
        if report["channels"]:
            channel_text = ", ".join(
                str(channel + 1) for channel in sorted(report["channels"])
            )
        else:
            channel_text = "None"

        print(f"\nTrack {number}: {display_name(report)}")
        print(f" Channels: {channel_text}")
        print(f" Total notes: {total_notes}")
        print(f" Drum notes ({DRUM_NOTE_MIN}-{DRUM_NOTE_MAX}): {drum_notes}")
        if total_notes > 0:
            share = (drum_notes / total_notes) * 100
            print(f" Drum note percentage: {share:.1f}%")
        print(f" Max simultaneous notes: {report['max_simultaneous']}")
        if report["avg_note_duration"] > 0:
            print(f" Avg note duration: {report['avg_note_duration']:.1f} ticks")
        print(f" Has pitch bend: {report['has_pitch_bend']}")
        print(f" Has control changes: {report['has_control_change']}")
        if report["program_changes"]:
            print(f" Program changes: {report['program_changes']}")
        print(f" DRUM SCORE: {report['score']:.1f}")

    if not reports:
        return None

    # --- Choose the winner -----------------------------------------------
    def ranking_key(entry):
        """Score first; then more notes, more polyphony, shorter notes, no CCs."""
        avg = entry["avg_note_duration"]
        shorter_is_better = -avg if avg > 0 else 0
        return (
            entry["score"],
            entry["note_count"],
            entry["max_simultaneous"],
            shorter_is_better,
            not entry["has_control_change"],
        )

    winner = max(reports, key=ranking_key)
    winning_score = winner["score"]
    winner_number = winner["track_index"] + 1
    same_score = [entry for entry in reports if entry["score"] == winning_score]

    print(f"\n{banner}")
    print("RESULT:")
    print(f"{banner}")

    if winning_score <= 0:
        print("No track identified as drums (all scores were 0)")
        return None

    print(f"Most likely drum track: Track {winner_number} ({display_name(winner)})")
    print(f"Confidence score: {winning_score:.1f}/100")

    # Report every track that shares the top score, in track order.
    if len(same_score) > 1:
        print(f"\n⚠ Note: {len(same_score)} tracks tied with score {winning_score:.1f}:")
        for entry in sorted(same_score, key=lambda item: item["track_index"]):
            print(
                f" - Track {entry['track_index'] + 1} ({display_name(entry)}): "
                f"{entry['note_count']} notes, "
                f"max {entry['max_simultaneous']} simultaneous"
            )
        print(
            " Selected based on: note count, polyphony, note duration, no control changes"
        )

    if winning_score >= 70:
        verdict = "\n✓ High confidence - This is very likely the drum track"
    elif winning_score >= 40:
        verdict = "\n⚠ Medium confidence - This might be the drum track"
    else:
        verdict = "\n⚠ Low confidence - Review manually"
    print(verdict)

    # --- Details of the selected track -----------------------------------
    # Lowest channel of the track, shown 1-based.
    first_channel = min(winner["channels"]) + 1 if winner["channels"] else "None"

    print("\nSelected track details:")
    print(f" Channel: {first_channel}")
    print(f" Total notes: {winner['note_count']}")
    if winner["note_count"] > 0:
        share = winner["drum_note_count"] / winner["note_count"] * 100
        print(f" Drum notes: {winner['drum_note_count']} ({share:.1f}%)")
    else:
        print(" Drum notes: 0")
    print(f" Max simultaneous: {winner['max_simultaneous']}")

    return (winner["track_index"], winner)
if __name__ == "__main__":
    # Command-line entry point: expects the path of the MIDI file to analyze
    # as the first argument.
    if len(sys.argv) < 2:
        # The argument is a MIDI file path (see the example line), so the
        # usage text names it accordingly.
        print("Usage: python analyze_drum_track.py <midi_file>")
        print("\nExample:")
        print(" python analyze_drum_track.py theperfectkiss.mid")
        sys.exit(1)

    midi_file_path = sys.argv[1]
    result = find_drum_track(midi_file_path)

    # find_drum_track returns None when no track could be selected; in that
    # case it has already printed the reason.
    if result:
        track_index, analysis = result
        print(f"\nTrack index (0-based): {track_index}")
        print(f"Track index (1-based): {track_index + 1}")