diff --git a/scripts/test.py b/scripts/test.py index 9776699..7d036f3 100755 --- a/scripts/test.py +++ b/scripts/test.py @@ -4,37 +4,101 @@ import dataclasses import json import os import subprocess +import time +import typing + +try: + import alsa_midi as alsa + WITH_MIDI = True +except ModuleNotFoundError as e: + print(f"ERROR: Failed to load module alsa_midi: {e}") + print(" Tests using this module will automatically fail\n" + " if not provided with --skip-midi flag") + exit(1) TEST_DIR = "regression-tests" TEST_DB = "test_db.json" INTERPRETER = "bin/linux/debug/musique" @dataclasses.dataclass -class Result: - exit_code: int = 0 - stdin_lines: list[str] = dataclasses.field(default_factory=list) - stdout_lines: list[str] = dataclasses.field(default_factory=list) - stderr_lines: list[str] = dataclasses.field(default_factory=list) +class MidiEvent: + type: str + args: list[str] + time: float + +def connect_to_default_midi_port(): + global midi_client + midi_client = alsa.SequencerClient('Musique Tester') + ports = midi_client.list_ports(input = True) + + for p in ports: + if p.client_id == 14: + input_port = p + break + else: + assert False, "Linux default MIDI port not found" + + port = midi_client.create_port('Musique tester') + port.connect_from(input_port) + +def listen_for_midi_events() -> list[MidiEvent] | None: + if not WITH_MIDI: + return None + + zero_time = time.monotonic() + + events = [] + while True: + event = midi_client.event_input(timeout=2) + if event is None: + break + end_time = time.monotonic() + events.append((event, end_time - zero_time)) + return events + + +def normalize_events(events) -> typing.Generator[MidiEvent, None, None]: + for event, time in events: + match event: + case alsa.event.NoteOnEvent(): + # TODO Support velocity + yield MidiEvent(type='note_on', args=[str(event.channel), str(event.note)], time=time) + case alsa.event.NoteOffEvent(): + yield MidiEvent(type='note_off', args=[str(event.channel), str(event.note)], time=time) + case _: + assert False, f"Unmatched event type: {event.type}" @dataclasses.dataclass -class TestCase: - name: str +class Result: exit_code: int = 0 - stdin_lines: list[str] = dataclasses.field(default_factory=list) stdout_lines: list[str] = dataclasses.field(default_factory=list) stderr_lines: list[str] = dataclasses.field(default_factory=list) + midi_events: list[MidiEvent] | None = None - def run(self, interpreter: str, source: str, cwd: str): - result = subprocess.run( +@dataclasses.dataclass +class TestCase(Result): + name: str = "" + stdin_lines: list[str] = dataclasses.field(default_factory=list) + + def run(self, interpreter: str, source: str, cwd: str, capture_midi: bool = False) -> Result: + process = subprocess.Popen( args=[interpreter, source, "-q"], - capture_output=True, cwd=cwd, - text=True + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE ) + + midi_events = listen_for_midi_events() if capture_midi else None + stdout, stderr = process.communicate() + return Result( - exit_code=result.returncode, - stdout_lines=result.stdout.splitlines(keepends=False), - stderr_lines=result.stderr.splitlines(keepends=False) + exit_code=process.wait(), + stdout_lines=stdout.splitlines(keepends=False), + stderr_lines=stderr.splitlines(keepends=False), + midi_events = None if midi_events is None else \ + list(normalize_events(midi_events)) ) def record(self, interpreter: str, source: str, cwd: str): @@ -47,13 +111,17 @@ class TestCase: if self.stdout_lines != result.stdout_lines: changes.append("stdout") if changes: print(f" changed: {', '.join(changes)}") - - self.exit_code, self.stderr_lines, self.stdout_lines = result.exit_code, result.stderr_lines, result.stdout_lines + self.exit_code = result.exit_code + self.stderr_lines = result.stderr_lines + self.stdout_lines = result.stdout_lines def test(self, interpreter: str, source: str, cwd: str): print(f" Testing case {self.name} ", end="") result = self.run(interpreter, source, cwd) - if self.exit_code == result.exit_code and self.stdout_lines == result.stdout_lines and self.stderr_lines == result.stderr_lines: + + if self.exit_code == result.exit_code \ + and self.stdout_lines == result.stdout_lines \ + and self.stderr_lines == result.stderr_lines: print("ok") return True @@ -191,9 +259,14 @@ if __name__ == "__main__": parser.add_argument("--update-all", action="store_true", help="Update all tests", dest="update_all") parser.add_argument("-a", "--add", action="append", help="Add new test to test suite", default=[]) parser.add_argument("-u", "--update", action="append", help="Update test case", default=[]) + parser.add_argument('--skip-midi', action="store_true", help="Skip tests expecting MIDI communication", default=False, dest="skip_midi") args = parser.parse_args() + WITH_MIDI = WITH_MIDI and not args.skip_midi + if WITH_MIDI: + connect_to_default_midi_port() + root = os.path.dirname(os.path.dirname(__file__)) testing_dir = os.path.join(root, TEST_DIR) test_db_path = os.path.join(testing_dir, TEST_DB)