Spaces:
Running
Running
| """Tests for app.py helper functions (no Gradio server required).""" | |
| import json | |
| from pathlib import Path | |
| import pytest | |
| from app import ( | |
| _parse_json_payload, | |
| config_event_bridge, | |
| get_config, | |
| save_midi_api, | |
| save_midi_event_bridge, | |
| process_engine_payload, | |
| process_engine_stream_cpu, | |
| ) | |
| PROJECT_ROOT = Path(__file__).resolve().parent.parent | |
# ---------------------------------------------------------------------------
# _parse_json_payload
# ---------------------------------------------------------------------------


class TestParseJsonPayload:
    """``_parse_json_payload(raw, default)``: decoded JSON, else *default*."""

    def test_valid_json(self):
        decoded = _parse_json_payload('{"a": 1}', {})
        assert decoded == {"a": 1}

    def test_none_returns_default(self):
        decoded = _parse_json_payload(None, "fallback")
        assert decoded == "fallback"

    def test_empty_string_returns_default(self):
        decoded = _parse_json_payload("", [])
        assert decoded == []

    def test_invalid_json_returns_default(self):
        decoded = _parse_json_payload("not json", 42)
        assert decoded == 42


# ---------------------------------------------------------------------------
# get_config
# ---------------------------------------------------------------------------


class TestGetConfig:
    """Shape checks on the dict returned by ``get_config()``."""

    def test_returns_instruments(self):
        config = get_config()
        assert "instruments" in config
        assert len(config["instruments"]) >= 1

    def test_returns_keyboard_keys(self):
        config = get_config()
        assert "keyboard_keys" in config
        assert isinstance(config["keyboard_keys"], list)

    def test_returns_keyboard_shortcuts(self):
        config = get_config()
        assert "keyboard_shortcuts" in config

    def test_returns_engines(self):
        config = get_config()
        assert "engines" in config
        # The "parrot" engine is the model-free engine other tests rely on.
        ids = [engine["id"] for engine in config["engines"]]
        assert "parrot" in ids

    def test_returns_runtime(self):
        config = get_config()
        assert "runtime" in config
        runtime = config["runtime"]
        for required_key in ("gpu_available", "default_mode"):
            assert required_key in runtime


# ---------------------------------------------------------------------------
# config_event_bridge
# ---------------------------------------------------------------------------


def test_config_event_bridge_returns_json():
    """The event bridge returns JSON text whose object includes ``instruments``."""
    decoded = json.loads(config_event_bridge("{}"))
    assert "instruments" in decoded


# ---------------------------------------------------------------------------
# save_midi_api
# ---------------------------------------------------------------------------


class TestSaveMidiApi:
    """``save_midi_api`` yields ``midi_base64`` on success, ``error`` otherwise."""

    def test_returns_base64(self, single_note_events):
        import base64

        result = save_midi_api(single_note_events)
        assert "midi_base64" in result
        encoded = result["midi_base64"]
        assert isinstance(encoded, str)
        # An isinstance check alone also passes for "" or arbitrary text, so
        # decode strictly (rejecting non-alphabet characters) and require a
        # non-empty payload.
        midi_bytes = base64.b64decode(encoded, validate=True)
        assert len(midi_bytes) > 0

    def test_empty_list(self):
        result = save_midi_api([])
        assert "error" in result

    def test_non_list(self):
        result = save_midi_api("not a list")
        assert "error" in result


# ---------------------------------------------------------------------------
# save_midi_event_bridge
# ---------------------------------------------------------------------------


def test_save_midi_bridge_round_trip(single_note_events):
    """Events sent to the bridge as JSON come back with a ``midi_base64`` field."""
    response_text = save_midi_event_bridge(json.dumps(single_note_events))
    response = json.loads(response_text)
    assert "midi_base64" in response


# ---------------------------------------------------------------------------
# process_engine_payload (parrot — no model download)
# ---------------------------------------------------------------------------


class TestProcessEnginePayload:
    """Dict-payload entry point, exercised on CPU with the ``parrot`` engine."""

    def test_parrot_round_trip(self, melody_events):
        outcome = process_engine_payload(
            {"engine_id": "parrot", "events": melody_events},
            device="cpu",
        )
        assert outcome.get("success") is True
        assert len(outcome["events"]) == len(melody_events)

    def test_missing_engine_id(self, melody_events):
        outcome = process_engine_payload({"events": melody_events}, device="cpu")
        assert "error" in outcome

    def test_missing_events(self):
        outcome = process_engine_payload({"engine_id": "parrot"}, device="cpu")
        assert "error" in outcome

    def test_invalid_payload_type(self):
        outcome = process_engine_payload("not a dict", device="cpu")
        assert "error" in outcome

    def test_unknown_engine(self, melody_events):
        bogus_request = {"engine_id": "does_not_exist", "events": melody_events}
        outcome = process_engine_payload(bogus_request, device="cpu")
        assert "error" in outcome


# ---------------------------------------------------------------------------
# process_engine_stream_cpu
# ---------------------------------------------------------------------------


def _single_stream_message(raw_payload):
    """Drain ``process_engine_stream_cpu`` and decode its only message.

    Every case below expects the generator to yield exactly one JSON string
    (a final ``complete`` result or an ``error``). This helper asserts that
    and returns the decoded dict.
    """
    messages = list(process_engine_stream_cpu(raw_payload))
    assert len(messages) == 1, f"expected exactly one message, got {messages!r}"
    return json.loads(messages[0])


def _assert_stream_error(raw_payload):
    """Assert the stream yields one error message that carries an ``error`` key."""
    parsed = _single_stream_message(raw_payload)
    assert parsed["status"] == "error"
    # Every error path should say *why*, not just flag the status; previously
    # only the missing-engine-id case checked this.
    assert "error" in parsed
    return parsed


class TestProcessEngineStreamCpu:
    """Streaming bridge: JSON string in, iterator of JSON strings out."""

    def test_non_streaming_engine_yields_single_complete(self, single_note_events):
        """parrot has no process_streaming; bridge should yield one complete result."""
        payload = json.dumps({"engine_id": "parrot", "events": single_note_events})
        parsed = _single_stream_message(payload)
        assert parsed["status"] == "complete"
        assert parsed.get("success") is True

    def test_missing_engine_id_yields_error(self, single_note_events):
        _assert_stream_error(json.dumps({"events": single_note_events}))

    def test_invalid_json_yields_error(self):
        _assert_stream_error("not json")

    def test_unknown_engine_yields_error(self, single_note_events):
        _assert_stream_error(
            json.dumps({"engine_id": "does_not_exist", "events": single_note_events})
        )

    def test_missing_events_yields_error(self):
        _assert_stream_error(json.dumps({"engine_id": "parrot"}))


# ---------------------------------------------------------------------------
# Static asset and deployment checks
# ---------------------------------------------------------------------------


def test_allowed_paths_in_app_source(app_source):
    """demo.launch() must include allowed_paths=['static'] for logo serving."""
    import re

    # Require the quoted 'static' entry *inside* the allowed_paths list
    # literal. Checking "allowed_paths" and "static" as independent
    # substrings would pass whenever "static" appeared anywhere in app.py.
    pattern = re.compile(r"""allowed_paths\s*=\s*\[[^\]]*["']static["']""")
    assert pattern.search(app_source), (
        "expected allowed_paths=[..., 'static', ...] in app.py"
    )