Source code for jdxi_editor.ui.editors.pattern.learner

"""
Pattern Learner Module

Handles MIDI input learning for the Pattern Sequencer. Captures incoming MIDI notes
and records them into the pattern grid in real-time, supporting both note-on and
note-off events with automatic step advancement.
"""

from typing import Callable, Dict, List, Optional

from decologr import Decologr as log
from mido import Message

from jdxi_editor.midi.conversion.note import MidiNoteConverter
from picomidi.message.type import MidoMessageType
from picomidi.messages.note import MidiNote


[docs] class PatternLearnerState: """Enumeration of learning states."""
[docs] IDLE = "idle"
[docs] LEARNING = "learning"
[docs] STOPPED = "stopped"
[docs] class PatternLearnerConfig: """Configuration for the pattern learner.""" def __init__( self, total_steps: int = 16, total_rows: int = 4, default_velocity: int = 100, default_duration_ms: float = 120.0, ): """ Initialize learner configuration. :param total_steps: Number of steps per bar (default 16) :param total_rows: Number of rows (default 4) :param default_velocity: Default velocity for learned notes :param default_duration_ms: Default duration in milliseconds """
[docs] self.total_steps = total_steps
[docs] self.total_rows = total_rows
[docs] self.default_velocity = default_velocity
[docs] self.default_duration_ms = default_duration_ms
[docs] class PatternLearnerEvent: """Represents a learned pattern event.""" def __init__( self, step: int, row: int, note: int, velocity: int, duration_ms: float, midi_note: MidiNote | None = None, ): """ Initialize a learned event. :param step: Step number (0-15) :param row: Row number (0-3) :param note: MIDI note number :param velocity: Note velocity (0-127) :param duration_ms: Note duration in milliseconds """
[docs] self.step = step
[docs] self.row = row
[docs] self.note = note
[docs] self.velocity = velocity
[docs] self.duration_ms = duration_ms
[docs] self.midi_note = midi_note or MidiNote( note=note, velocity=velocity, duration_ms=duration_ms )
[docs] class PatternLearner: """ Learns pattern from incoming MIDI messages. Captures MIDI note-on and note-off events and records them into a pattern grid. Automatically advances through steps as notes are released. """ def __init__( self, config: Optional[PatternLearnerConfig] = None, midi_converter: Optional[MidiNoteConverter] = None, scope: str = "PatternLearner", ): """ Initialize the pattern learner. :param config: Learner configuration. Defaults to standard 16-step, 4-row configuration :param midi_converter: MIDI note converter for validating notes. If not provided, uses default conversion without drum validation :param scope: Logging scope name """
[docs] self.config = config or PatternLearnerConfig()
[docs] self.midi_converter = midi_converter
[docs] self.scope = scope
# Learning state
[docs] self.state = PatternLearnerState.IDLE
[docs] self.current_step = 0
[docs] self.current_measure_index = 0
# Learned pattern storage # learned_pattern[row][step] = MIDI note number or None
[docs] self.learned_pattern: List[List[Optional[int]]] = [ [None] * self.config.total_steps for _ in range(self.config.total_rows) ]
# Track which notes are currently active (for matching note-off events) # active_notes[midi_note] = row_index
[docs] self.active_notes: Dict[int, int] = {}
# MIDI track for recording raw messages from mido import MidiTrack
[docs] self.midi_track = MidiTrack()
# Callbacks for external event handling
[docs] self.on_step_advance: Optional[Callable[[int], None]] = None
[docs] self.on_note_learned: Optional[Callable[[PatternLearnerEvent], None]] = None
[docs] self.on_learning_stopped: Optional[Callable[[], None]] = None
[docs] self.on_learning_started: Optional[Callable[[], None]] = None
# Events that occurred during learning
[docs] self.learned_events: List[PatternLearnerEvent] = []
[docs] def start_learning(self) -> None: """Start the pattern learning process.""" if self.state == PatternLearnerState.LEARNING: log.debug( message="Learning already in progress", scope=self.scope, ) return self.state = PatternLearnerState.LEARNING self.current_step = 0 self.active_notes.clear() self.learned_pattern = [ [None] * self.config.total_steps for _ in range(self.config.total_rows) ] self.learned_events.clear() self.midi_track = __import__("mido").MidiTrack() log.message( message="Pattern learning started", scope=self.scope, ) if self.on_learning_started: self.on_learning_started()
[docs] def stop_learning(self) -> None: """Stop the pattern learning process.""" if self.state != PatternLearnerState.LEARNING: log.debug( message="No learning in progress", scope=self.scope, ) return self.state = PatternLearnerState.STOPPED log.message( message=f"Pattern learning stopped after {self.current_step} steps", scope=self.scope, ) if self.on_learning_stopped: self.on_learning_stopped()
[docs] def process_midi_message(self, message: Message) -> None: """ Process an incoming MIDI message during learning. Handles NOTE_ON (velocity > 0) and NOTE_OFF messages. Automatically advances the step when a note-off is received. :param message: Mido Message object """ if self.state != PatternLearnerState.LEARNING: return try: if message.type == MidoMessageType.NOTE_ON.value and message.velocity > 0: self._handle_note_on(message) elif message.type == MidoMessageType.NOTE_OFF.value or ( message.type == MidoMessageType.NOTE_ON.value and message.velocity == 0 ): self._handle_note_off(message) except Exception as ex: log.error( message=f"Error processing MIDI message: {ex}", scope=self.scope, )
[docs] def _handle_note_on(self, message: Message) -> None: """ Handle a NOTE_ON message. Records the note in the learned pattern and marks it as active. :param message: Mido NOTE_ON message with velocity > 0 """ note = message.note velocity = message.velocity # Determine which row this note belongs to row = self._find_row_for_note(note) if row is None: log.debug( message=f"Note {note} does not match any row range", scope=self.scope, ) return step_in_bar = self.current_step % self.config.total_steps # Record the note in the learned pattern self.learned_pattern[row][step_in_bar] = note self.active_notes[note] = row # Get default duration duration_ms = self.config.default_duration_ms # Create and store learned event event = PatternLearnerEvent( step=step_in_bar, row=row, note=note, velocity=velocity, duration_ms=duration_ms, ) self.learned_events.append(event) # Add to MIDI track self.midi_track.append( Message( MidoMessageType.NOTE_ON.value, note=note, velocity=velocity, time=0, ) ) log.message( message=f"Note learned: row {row}, step {step_in_bar}, note {note}, velocity {velocity}", scope=self.scope, ) if self.on_note_learned: self.on_note_learned(event)
[docs] def _handle_note_off(self, message: Message) -> None: """ Handle a NOTE_OFF message or NOTE_ON with velocity 0. Marks the note as inactive and advances to the next step. :param message: Mido NOTE_OFF message or NOTE_ON with velocity 0 """ note = message.note if note not in self.active_notes: log.debug( message=f"Note off for unrecognized note {note}", scope=self.scope, ) return # Remove the note from active notes row = self.active_notes.pop(note) log.message( message=f"Note off: note {note} (row {row}) at step {self.current_step}", scope=self.scope, ) # Add note_off to MIDI track self.midi_track.append( Message(MidoMessageType.NOTE_OFF.value, note=note, velocity=0, time=0) ) # Advance to next step self._advance_step()
[docs] def _advance_step(self) -> None: """Advance to the next step in the pattern.""" prev_step = self.current_step self.current_step = (self.current_step + 1) % self.config.total_steps # Stop learning after completing one full bar (16 steps) if self.current_step == 0: log.message( message="Learning complete after one full bar (16 steps)", scope=self.scope, ) self.stop_learning() return log.message( message=f"Advanced to step {self.current_step}", scope=self.scope, ) if self.on_step_advance: self.on_step_advance(self.current_step)
[docs] def _find_row_for_note(self, midi_note: int) -> Optional[int]: """ Find the row that corresponds to a MIDI note. Uses the MIDI converter if available, otherwise falls back to hardcoded ranges. :param midi_note: MIDI note number :return: Row index (0-3) or None if note doesn't match any row """ if self.midi_converter: for row in range(self.config.total_rows): if self.midi_converter.is_note_in_row_range(row, midi_note): return row return None # Fallback to hardcoded ranges ranges = { 0: range(60, 72), # C4 to B4 (Digital Synth 1) 1: range(60, 72), # C4 to B4 (Digital Synth 2) 2: range(48, 60), # C3 to B3 (Analog Synth) 3: range(36, 48), # C2 to B2 (Drums) } for row, note_range in ranges.items(): if midi_note in note_range: return row return None
[docs] def clear_learned_pattern(self) -> None: """Clear the learned pattern and reset state.""" self.learned_pattern = [ [None] * self.config.total_steps for _ in range(self.config.total_rows) ] self.learned_events.clear() self.active_notes.clear() self.current_step = 0 self.midi_track = __import__("mido").MidiTrack() log.message( message="Cleared learned pattern", scope=self.scope, )
[docs] def get_learned_pattern(self) -> List[List[Optional[int]]]: """ Get the learned pattern. :return: 2D list where pattern[row][step] = MIDI note number or None """ return self.learned_pattern
[docs] def get_learned_events(self) -> List[PatternLearnerEvent]: """ Get all events that occurred during learning. :return: List of PatternLearnerEvent objects """ return self.learned_events
[docs] def get_pattern_for_row(self, row: int) -> List[Optional[int]]: """ Get the learned pattern for a specific row. :param row: Row index (0-3) :return: List of MIDI notes for the row (or None for empty steps) """ if 0 <= row < len(self.learned_pattern): return self.learned_pattern[row] return [None] * self.config.total_steps
[docs] def get_midi_track(self): """ Get the recorded MIDI track. :return: MidiTrack object with NOTE_ON and NOTE_OFF messages """ return self.midi_track
[docs] def is_learning(self) -> bool: """Check if currently learning.""" return self.state == PatternLearnerState.LEARNING
[docs] def set_config(self, config: PatternLearnerConfig) -> None: """ Update the learner configuration. :param config: New configuration """ self.config = config
[docs] def set_midi_converter(self, converter: MidiNoteConverter) -> None: """ Set or update the MIDI converter. :param converter: MidiNoteConverter instance """ self.midi_converter = converter
# Example usage if __name__ == "__main__": from jdxi_editor.midi.conversion.note import MidiNoteConverter # Create converter and learner
[docs] converter = MidiNoteConverter()
learner = PatternLearner(midi_converter=converter) # Set up callbacks def on_note_learned(event: PatternLearnerEvent): log.message( f"Note learned: Row {event.row}, Step {event.step}, " f"Note {event.note}, Velocity {event.velocity}" ) def on_step_advance(step: int): log.message(f"Advanced to step {step}") learner.on_note_learned = on_note_learned learner.on_step_advance = on_step_advance # Simulate learning learner.start_learning() # Simulate some MIDI input learner.process_midi_message( Message(MidoMessageType.NOTE_ON.value, note=60, velocity=100) ) learner.process_midi_message( Message(MidoMessageType.NOTE_OFF.value, note=60, velocity=0) ) learner.process_midi_message( Message(MidoMessageType.NOTE_ON.value, note=64, velocity=80) ) learner.process_midi_message( Message(MidoMessageType.NOTE_OFF.value, note=64, velocity=0) ) # Get results log.message("\nLearned pattern:", learner.get_learned_pattern()) log.message("Learned events:", len(learner.get_learned_events()))