asigalov61 committed on
Commit
7844cbe
·
verified ·
1 Parent(s): 8256dcc

Upload 2 files

Browse files
Files changed (2) hide show
  1. TCUPY.py +44 -3
  2. TMIDIX.py +934 -3
TCUPY.py CHANGED
@@ -9,14 +9,14 @@ r'''############################################################################
9
  #
10
  # Project Los Angeles
11
  #
12
- # Tegridy Code 2025
13
  #
14
  # https://github.com/asigalov61/tegridy-tools
15
  #
16
  #
17
  ################################################################################
18
  #
19
- # Copyright 2024 Project Los Angeles / Tegridy Code
20
  #
21
  # Licensed under the Apache License, Version 2.0 (the "License");
22
  # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ r'''############################################################################
36
  # Critical dependencies
37
  #
38
  # !pip install cupy-cuda12x
39
- # !pip install numpy==1.24.4
40
  #
41
  ################################################################################
42
  '''
@@ -52,6 +52,7 @@ print('=' * 70)
52
 
53
  import sys
54
  import os
 
55
 
56
  ################################################################################
57
 
@@ -1192,6 +1193,46 @@ def embeddings_topk_cosine_neighbors(embeddings,
1192
 
1193
  ###################################################################################
1194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1195
  print('Module is loaded!')
1196
  print('Enjoy! :)')
1197
  print('=' * 70)
 
9
  #
10
  # Project Los Angeles
11
  #
12
+ # Tegridy Code 2026
13
  #
14
  # https://github.com/asigalov61/tegridy-tools
15
  #
16
  #
17
  ################################################################################
18
  #
19
+ # Copyright 2026 Project Los Angeles / Tegridy Code
20
  #
21
  # Licensed under the Apache License, Version 2.0 (the "License");
22
  # you may not use this file except in compliance with the License.
 
36
  # Critical dependencies
37
  #
38
  # !pip install cupy-cuda12x
39
+ # !pip install numpy==1.26.4
40
  #
41
  ################################################################################
42
  '''
 
52
 
53
  import sys
54
  import os
55
+ import tqdm
56
 
57
  ################################################################################
58
 
 
1193
 
1194
  ###################################################################################
1195
 
1196
def find_matches_fast(src_array, trg_array, seed: int = 0) -> int:

    """
    Count how many rows of src_array have a fingerprint that also occurs
    among the rows of trg_array, using CuPy (GPU).

    Every row is viewed as raw bytes and hashed with a 64-bit FNV-1a hash;
    the returned number is how many src fingerprints are found in the set
    of trg fingerprints (each src row is counted individually).

    NOTE: rows are compared through their 64-bit hashes only, never
    directly, so a hash collision (unlikely but possible) is counted as a
    match. Rows are hashed byte-wise, so only byte-identical rows are
    guaranteed to produce the same fingerprint.

    Parameters
    ----------
    src_array, trg_array : array-like
        2D arrays with the same dtype and the same number of columns.
    seed : int
        XOR-ed into the FNV offset basis. Only the low 64 bits are used,
        so negative or very large seeds are accepted.

    Returns
    -------
    int
        Number of src rows whose fingerprint appears in trg.

    Raises
    ------
    ValueError
        If the inputs are not both 2D with equal dtype and column count.
    """

    src = cp.ascontiguousarray(cp.asarray(src_array))
    trg = cp.ascontiguousarray(cp.asarray(trg_array))

    if src.dtype != trg.dtype or src.ndim != 2 or trg.ndim != 2 or src.shape[1] != trg.shape[1]:
        raise ValueError("src and trg must be 2D arrays with same dtype and same number of columns")

    # bytes per row
    bpr = src.dtype.itemsize * src.shape[1]

    # view rows as bytes
    src_bytes = src.view(cp.uint8).reshape(src.shape[0], bpr)
    trg_bytes = trg.view(cp.uint8).reshape(trg.shape[0], bpr)

    # FNV-1a constants; the mask keeps the seeded offset inside the uint64
    # range so that a negative or oversize seed cannot overflow
    FNV_OFFSET = cp.uint64((0xcbf29ce484222325 ^ seed) & 0xFFFFFFFFFFFFFFFF)
    FNV_PRIME = cp.uint64(0x100000001b3)

    # hash rows: one XOR + multiply per byte column, applied to all rows at once
    def fnv1a_hash(byte_matrix):
        h = cp.full((byte_matrix.shape[0],), FNV_OFFSET, dtype=cp.uint64)
        for i in range(bpr):
            h ^= byte_matrix[:, i].astype(cp.uint64)
            h *= FNV_PRIME
        return h

    src_fp = fnv1a_hash(src_bytes)
    trg_fp = fnv1a_hash(trg_bytes)

    # count matches
    return int(cp.isin(src_fp, trg_fp).sum())
1233
+
1234
+ ###################################################################################
1235
+
1236
  print('Module is loaded!')
1237
  print('Enjoy! :)')
1238
  print('=' * 70)
TMIDIX.py CHANGED
@@ -51,7 +51,7 @@ r'''############################################################################
51
 
52
  ###################################################################################
53
 
54
- __version__ = "26.2.27"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
@@ -1503,7 +1503,9 @@ import statistics
1503
  import math
1504
  from math import gcd
1505
 
1506
- from functools import reduce
 
 
1507
 
1508
  import matplotlib.pyplot as plt
1509
 
@@ -1522,7 +1524,7 @@ from array import array
1522
  from pathlib import Path
1523
  from fnmatch import fnmatch
1524
 
1525
- from typing import List, Optional, Tuple, Dict
1526
 
1527
  ###################################################################################
1528
  #
@@ -16657,6 +16659,935 @@ def merge_text_files(files,
16657
  print(f"Merged {len(files)} files into {output_path}")
16658
 
16659
  ###################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16660
 
16661
  print('Module loaded!')
16662
  print('=' * 70)
 
51
 
52
  ###################################################################################
53
 
54
+ __version__ = "26.3.2"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
 
1503
  import math
1504
  from math import gcd
1505
 
1506
+ from functools import reduce, lru_cache
1507
+
1508
+ import struct
1509
 
1510
  import matplotlib.pyplot as plt
1511
 
 
1524
  from pathlib import Path
1525
  from fnmatch import fnmatch
1526
 
1527
+ from typing import List, Optional, Tuple, Dict, Any
1528
 
1529
  ###################################################################################
1530
  #
 
16659
  print(f"Merged {len(files)} files into {output_path}")
16660
 
16661
  ###################################################################################
16662
+
16663
def chord_cost(input_pc,
               candidate,
               del_white=5.0,
               del_black=1.0,
               ins_white=1.0,
               ins_black=5.0,
               col_change_w2b=3.0,
               col_change_b2w=0.5
               ):

    """
    Return the cheapest cost of turning the notes of input_pc into candidate.

    Each input note is either deleted or paired with one not-yet-used
    candidate note; candidate notes left unpaired are paid for as
    insertions. Pairing costs the circular (mod 12) distance between the
    two notes plus a colour-change penalty when exactly one of them is in
    BLACK_NOTES. A note not in BLACK_NOTES is treated as white.

    del_white / del_black : cost of deleting a white / black input note
    ins_white / ins_black : cost of inserting a white / black candidate note
    col_change_w2b        : extra cost for pairing white input with black candidate
    col_change_b2w        : extra cost for pairing black input with white candidate
    """

    n_in = len(input_pc)
    n_cand = len(candidate)

    # True for black notes, False for white ones
    in_black = [note in BLACK_NOTES for note in input_pc]
    cand_black = [note in BLACK_NOTES for note in candidate]

    # pair_cost[i][j]: circular distance plus colour-change penalty
    pair_cost = []
    for i, a in enumerate(input_pc):
        row = []
        for j, b in enumerate(candidate):
            gap = abs(a - b)
            step = min(gap, 12 - gap)
            if in_black[i] and not cand_black[j]:
                step += col_change_b2w
            elif cand_black[j] and not in_black[i]:
                step += col_change_w2b
            row.append(step)
        pair_cost.append(row)

    @lru_cache(maxsize=None)
    def search(i, used_mask):
        # used_mask has bit j set when candidate note j is already paired
        if i == n_in:
            # every input note handled: pay for the unpaired candidate notes
            leftover = 0.0
            for j in range(n_cand):
                if (used_mask >> j) & 1:
                    continue
                leftover += ins_black if cand_black[j] else ins_white
            return leftover

        # option 1: drop the current input note
        drop_cost = del_black if in_black[i] else del_white
        best = drop_cost + search(i + 1, used_mask)

        # option 2: pair it with any candidate note that is still free
        for j in range(n_cand):
            if (used_mask >> j) & 1:
                continue
            trial = pair_cost[i][j] + search(i + 1, used_mask | (1 << j))
            if trial < best:
                best = trial

        return best

    return search(0, 0)
16721
+
16722
+ ###################################################################################
16723
+
16724
def expert_check_and_fix_tones_chord(tones_chord, use_full_chords=False, **kwargs):

    """
    Return the chord from the selected chord table that is cheapest to
    reach from the given pitch classes according to chord_cost.

    The input is de-duplicated and sorted first. An empty input gives [],
    an input already present in the table is returned as is. On equal
    cost the chord with fewer notes in BLACK_NOTES wins.

    use_full_chords selects ALL_CHORDS_FULL instead of ALL_CHORDS_SORTED.

    ------
    KWARGS
    ------

    Forwarded unchanged to chord_cost:

    del_white      = 5.0 # cost of deleting a white note
    del_black      = 1.0 # cost of deleting a black note
    ins_white      = 1.0 # cost of adding a white note
    ins_black      = 5.0 # cost of adding a black note
    col_change_w2b = 3.0 # extra cost of a white -> black pairing
    col_change_b2w = 0.5 # extra cost of a black -> white pairing
    """

    tones_chord = sorted(set(tones_chord))

    if not tones_chord:
        return []

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    # Exact match needs no fixing
    if tones_chord in CHORDS:
        return tones_chord

    def black_count(chord):
        return sum(1 for n in chord if n in BLACK_NOTES)

    best_chord = None
    best_cost = float('inf')

    for chord in CHORDS:
        cost = chord_cost(tones_chord, chord, **kwargs)

        if cost < best_cost:
            best_cost, best_chord = cost, chord
            continue

        # Tie-breaker: prefer the chord with fewer black notes
        if cost == best_cost and best_chord is not None:
            if black_count(chord) < black_count(best_chord):
                best_chord = chord

    if not best_chord:
        return []

    return sorted(best_chord)
16773
+
16774
+ ###################################################################################
16775
+
16776
def expert_check_and_fix_pitches_chord(pitches_chord, use_full_chords=False, **kwargs):

    """
    Fix a chord given as pitches so that its pitch classes form a chord
    from the selected chord table.

    Pitches whose pitch class survives in the fixed tones chord are kept.
    Every other pitch is moved one octave towards the middle (down when
    its octave is above 4, up otherwise) and set to the closest tone
    picked by find_closest_tone: from the newly introduced tones when
    there are any, otherwise from the surviving tones. While the result
    collides with a kept or already produced pitch it is lowered by one
    more octave.

    Returns the resulting pitches sorted from highest to lowest.
    kwargs are forwarded to expert_check_and_fix_tones_chord.
    """

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    pitches = sorted(set(pitches_chord), reverse=True)

    tones_chord = sorted({p % 12 for p in pitches})

    if tones_chord in CHORDS:
        fixed_tones_chord = tones_chord

    else:
        fixed_tones_chord = expert_check_and_fix_tones_chord(tones_chord,
                                                             use_full_chords=use_full_chords,
                                                             **kwargs
                                                             )

    # tones present both before and after the fix / tones only the fix added
    same_tones = sorted(set(tones_chord) & set(fixed_tones_chord))
    new_tones = sorted(set(same_tones) ^ set(fixed_tones_chord))

    good_pitches = [p for p in pitches if p % 12 in same_tones]

    # pool the replacement tone is picked from
    target_tones = new_tones if new_tones else same_tones

    new_pitches = []

    for p in pitches:
        tone = p % 12

        if tone in same_tones:
            new_pitches.append(p)
            continue

        octave = p // 12
        octave = octave - 1 if octave > 4 else octave + 1

        ntone = find_closest_tone(target_tones, tone)

        new_pitch = (octave * 12) + ntone

        # step down by octaves until the pitch is not taken yet
        while new_pitch in good_pitches or new_pitch in new_pitches:
            octave -= 1
            new_pitch = (octave * 12) + ntone

        new_pitches.append(new_pitch)

    return sorted(new_pitches, reverse=True)
16840
+
16841
+ ###################################################################################
16842
+
16843
def split_escore_notes_by_channel(escore_notes, chan_idx=3):

    """
    Group notes by the value stored at index chan_idx.

    Returns a dict mapping each distinct value (in ascending order) to the
    list of notes carrying it; notes keep their input order within a group.
    """

    def channel_of(note):
        return note[chan_idx]

    ordered_notes = sorted(escore_notes, key=channel_of)

    groups = {}

    for chan, members in groupby(ordered_notes, key=channel_of):
        groups[chan] = list(members)

    return groups
16848
+
16849
+ ###################################################################################
16850
+
16851
def split_escore_notes_by_patch(escore_notes, pat_idx=6):

    """
    Group notes by the value stored at index pat_idx.

    Returns a dict mapping each distinct value (in ascending order) to the
    list of notes carrying it; notes keep their input order within a group.
    """

    def patch_of(note):
        return note[pat_idx]

    ordered_notes = sorted(escore_notes, key=patch_of)

    groups = {}

    for patch, members in groupby(ordered_notes, key=patch_of):
        groups[patch] = list(members)

    return groups
16856
+
16857
+ ###################################################################################
16858
+
16859
def expert_check_and_fix_chords_in_escore_notes(escore_notes,
                                                use_full_chords=False,
                                                split_by_channel=False,
                                                **kwargs
                                                ):

    """
    Run expert_check_and_fix_pitches_chord over every chord of a score.

    The score is split into chords with chordify_score([1000, escore_notes])
    (presumably groups of simultaneous notes - confirm against
    chordify_score). Inside each chord the notes are grouped by channel
    (split_by_channel=True, group 9 is passed through untouched) or by
    patch (default, group 128 is passed through untouched).

    For every other group only the first note of each pitch (index 4) is
    kept, the kept notes are ordered from highest to lowest pitch, and
    each one is deep-copied with its pitch replaced by the fixed pitch at
    the same position.

    kwargs are forwarded to expert_check_and_fix_pitches_chord.
    Returns the resulting notes as one sorted list.
    """

    fixed_score = []

    for chord in chordify_score([1000, escore_notes]):

        if split_by_channel:
            groups = split_escore_notes_by_channel(chord)
            passthrough_key = 9

        else:
            groups = split_escore_notes_by_patch(chord)
            passthrough_key = 128

        for key, events in groups.items():

            if key == passthrough_key:
                fixed_score.extend(events)
                continue

            # keep the first event for every distinct pitch
            unique_events = []
            seen_pitches = set()

            for ev in events:
                if ev[4] not in seen_pitches:
                    seen_pitches.add(ev[4])
                    unique_events.append(ev)

            unique_events.sort(key=lambda x: -x[4])

            fixed_pitches_chord = expert_check_and_fix_pitches_chord([ev[4] for ev in unique_events],
                                                                     use_full_chords=use_full_chords,
                                                                     **kwargs
                                                                     )

            for i, ev in enumerate(unique_events):
                fixed_ev = copy.deepcopy(ev)
                fixed_ev[4] = fixed_pitches_chord[i]
                fixed_score.append(fixed_ev)

    def order_key(x):
        # entries whose index 6 equals 128 are ordered by index 6 before pitch
        if x[6] != 128:
            return (x[1], -x[4], x[6])
        return (x[1], x[6], -x[4])

    return sorted(fixed_score, key=order_key)
16911
+
16912
+ ###################################################################################
16913
+
16914
def sparse_random_int_list(length, sparsity=0.01, value_range=(1, 100)):

    """
    Build a list of `length` integers that is mostly zeros.

    length: number of elements in the returned list
    sparsity: chance (0.0 - 1.0) for each element to hold a random value
              instead of 0 (0.01 = 1%)
    value_range: inclusive (min, max) bounds passed to random.randint
    """

    low, high = value_range

    result = []

    for _ in range(length):
        # draw the coin first, the value only when the coin says so
        if random.random() < sparsity:
            result.append(random.randint(low, high))
        else:
            result.append(0)

    return result
16930
+
16931
+ ###################################################################################
16932
+
16933
def detect_list_values_type(values):

    """
    Detect the most specific type that can represent all values in the list.

    Returns None for an empty input, otherwise one of:
    'bool', 'byte', 'int8', 'int16', 'int32', 'int64',
    'float32', 'float64', 'object'

    - 'bool'    : every value is a bool
    - 'byte'    : every value is an int in 0 .. 255
    - 'int8' .. 'int64' : every value is an int inside the signed range
    - 'float32' : every value is a float that changes by less than 1e-7
                  when round-tripped through a 32-bit float
    - 'float64' : every value is a float, but at least one does not
                  survive the 32-bit round trip (this includes values
                  outside the float32 range)
    - 'object'  : mixed types, or integers outside the int64 range
    """

    if not values:
        return None

    # --- BOOL CHECK ---
    if all(isinstance(v, bool) for v in values):
        return "bool"

    # --- INT CHECK --- (bool is a subclass of int, so exclude it explicitly)
    if all(isinstance(v, int) and not isinstance(v, bool) for v in values):
        mn, mx = min(values), max(values)

        # byte (unsigned 8-bit)
        if 0 <= mn and mx <= 255:
            return "byte"

        # int8 (signed 8-bit)
        if -128 <= mn and mx <= 127:
            return "int8"

        # int16
        if -32768 <= mn and mx <= 32767:
            return "int16"

        # int32
        if -2147483648 <= mn and mx <= 2147483647:
            return "int32"

        # int64
        if -2**63 <= mn and mx <= 2**63 - 1:
            return "int64"

        # arbitrary-precision ints do not fit any fixed-width type
        return "object"

    # --- FLOAT CHECK ---
    if all(isinstance(v, float) for v in values):

        def survives_float32(x):
            try:
                narrowed = struct.unpack("!f", struct.pack("!f", x))[0]
            except OverflowError:
                # finite value too large for a 32-bit float
                return False
            return abs(narrowed - x) < 1e-7

        if all(survives_float32(v) for v in values):
            return "float32"

        return "float64"

    # --- MIXED TYPES ---
    return "object"
16985
+
16986
+ ###################################################################################
16987
+
16988
+ # ---------- VarInt helpers (operate on bytearray) ----------
16989
def _write_varint_to_bytearray(n: int, out: bytearray) -> None:
    """
    Append n to out as a VarInt: 7-bit groups, least significant group
    first, with the top bit of a byte set when more groups follow.

    Raises ValueError for a negative n (nothing is appended in that case).
    """
    if n < 0:
        raise ValueError("VarInt only works with non-negative integers")

    # every group except the last carries the continuation bit
    while n > 0x7F:
        out.append((n & 0x7F) | 0x80)
        n >>= 7

    out.append(n)
17001
+
17002
+ ###################################################################################
17003
+
17004
def _read_varint_from_bytearray(data: bytearray, pos: int) -> Tuple[int, int]:
    """
    Decode one VarInt from data beginning at index pos.

    Returns (value, index of the first byte after the VarInt).
    Raises ValueError when data ends inside the VarInt, or when a
    continuation bit is still set after more than 10 * 7 payload bits.
    """
    end = len(data)
    result = 0
    bits = 0
    more = True

    while more:
        if pos >= end:
            raise ValueError("Unexpected end of data while reading VarInt")

        current = data[pos]
        pos += 1

        result |= (current & 0x7F) << bits
        bits += 7

        more = bool(current & 0x80)

        # safety limit against endless / malformed VarInts
        if more and bits > 10 * 7:
            raise ValueError("VarInt too large or malformed (excessive length)")

    return result, pos
17021
+
17022
+ ###################################################################################
17023
+
17024
+ # ---------- ZigZag (signed ↔ unsigned) ----------
17025
def _zigzag_encode(n: int) -> int:
    """
    Map a signed int to a non-negative int (ZigZag):
    0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
    Works for Python ints of any size.
    """
    if n < 0:
        # negatives land on the odd numbers
        return ((-n) << 1) - 1

    # non-negatives land on the even numbers
    return n << 1
17031
+
17032
+ ###################################################################################
17033
+
17034
def _zigzag_decode(n: int) -> int:
    """
    Undo ZigZag encoding: even inputs give n / 2, odd inputs give
    -(n // 2 + 1).
    """
    half = n >> 1

    if n & 1:
        return -(half + 1)

    return half
17037
+
17038
+ ###################################################################################
17039
+ # ---------- Helpers ----------
17040
def _fits_in_signed(bits: int, v: int) -> bool:
    """Tell whether v lies in the two's-complement range of a `bits`-wide signed integer."""
    half = 1 << (bits - 1)
    return -half <= v <= half - 1
17044
+
17045
+ ###################################################################################
17046
+
17047
def _fits_in_unsigned(bits: int, v: int) -> bool:
    """Tell whether v lies in 0 .. 2**bits - 1."""
    if v < 0:
        return False

    return v <= (1 << bits) - 1
17049
+
17050
+ ###################################################################################
17051
+
17052
def _choose_fixed_type(values: List[int]) -> str:
    """
    Pick the narrowest storage type name that holds every value.

    Order of preference: 'bool' (all values are 0 or 1), 'byte' (all in
    0 .. 255), then the signed types 'int8', 'int16', 'int32'; anything
    that fits none of these yields 'int64'.
    An empty list yields 'int32'.
    """
    if not values:
        return 'int32'  # default when there is nothing to inspect

    if all(v in (0, 1) for v in values):
        return 'bool'

    lowest = min(values)
    highest = max(values)

    # unsigned 8-bit
    if 0 <= lowest and highest <= (1 << 8) - 1:
        return 'byte'  # alias for uint8

    # signed types, narrowest first
    for name, bits in (('int8', 8), ('int16', 16), ('int32', 32)):
        half = 1 << (bits - 1)
        if -half <= lowest and highest <= half - 1:
            return name

    return 'int64'
17070
+
17071
+ ###################################################################################
17072
+
17073
+ # ---------- Public API ----------
17074
def encode_sparse_list(lst: List[int], value_type: Optional[str] = None) -> List[int]:
    """
    Compress a sparse list of integers into a list of bytes (ints 0-255).

    Stream layout:
        VarInt(len(lst)), VarInt(number of non-zero entries), then for every
        non-zero entry a VarInt index delta followed by the value payload.
        The first delta is the absolute index of the first non-zero entry
        (so it is 0 when lst[0] is non-zero); every later delta is the
        distance to the previous non-zero index and is always >= 1.

    Parameters:
        lst: input list of integers.
        value_type:
            None    -> ZigZag + VarInt per value (works for any signed int)
            'auto'  -> pick the smallest fixed-width type (or bool) from the
                       non-zero values; the chosen type is NOT stored in the
                       stream, the caller has to keep it for decoding
            'bool'  -> store only positions; no value bytes are written, so
                       every entry decodes as 1
            'byte' or 'uint8' -> one unsigned byte per non-zero value
            'int8','int16','int32','int64' -> little-endian fixed-width
                       signed value per non-zero

    Returns:
        List[int] of bytes (0-255).

    Raises:
        ValueError: unsupported value_type, or a value outside 0 .. 255 for
                    'byte'/'uint8'.
        struct.error: a value does not fit the requested signed fixed-width type.
    """
    non_zeros = [(i, val) for i, val in enumerate(lst) if val != 0]
    k = len(non_zeros)
    n = len(lst)

    # If auto, decide based on values
    if value_type == 'auto':
        values = [val for _, val in non_zeros]
        value_type = _choose_fixed_type(values)

    out = bytearray()
    _write_varint_to_bytearray(n, out)
    _write_varint_to_bytearray(k, out)

    prev_idx = 0
    for entry_no, (idx, val) in enumerate(non_zeros):
        delta = idx - prev_idx

        # the first delta is an absolute index and may be 0 (non-zero value
        # at index 0); all following deltas must move forward
        if delta < 0 or (delta == 0 and entry_no > 0):
            raise ValueError("Indices must be strictly increasing")
        _write_varint_to_bytearray(delta, out)

        if value_type is None:
            # ZigZag + VarInt
            _write_varint_to_bytearray(_zigzag_encode(val), out)
        elif value_type == 'bool':
            # no value bytes stored; value implicitly 1
            pass
        elif value_type in ('byte', 'uint8'):
            if not _fits_in_unsigned(8, val):
                raise ValueError(f"value {val} out of range for uint8")
            out.append(val & 0xFF)
        elif value_type == 'int8':
            out.extend(struct.pack('<b', val))
        elif value_type == 'int16':
            out.extend(struct.pack('<h', val))
        elif value_type == 'int32':
            out.extend(struct.pack('<i', val))
        elif value_type == 'int64':
            out.extend(struct.pack('<q', val))
        else:
            raise ValueError(f"Unsupported value_type: {value_type}")

        prev_idx = idx

    # Return as list[int] for compatibility
    return list(out)

###################################################################################

def decode_sparse_list(encoded: List[int], value_type: Optional[str] = None) -> List[int]:
    """
    Decompress a list of bytes (ints) back into the original integer list.

    Parameters:
        encoded: list of bytes (ints 0-255) produced by encode_sparse_list.
        value_type: must be the concrete type used during encoding (None,
                    'bool', 'byte'/'uint8', 'int8', 'int16', 'int32' or
                    'int64'). 'auto' is not accepted here: when the data was
                    encoded with 'auto', pass the type that was chosen.

    Returns:
        The reconstructed list of integers.

    Raises:
        ValueError: truncated or malformed data, a zero delta after the
                    first entry, or an unsupported value_type.
        IndexError: a decoded index falls outside the stored list length.
    """
    data = bytearray(encoded)
    pos = 0

    n, pos = _read_varint_from_bytearray(data, pos)
    k, pos = _read_varint_from_bytearray(data, pos)

    result = [0] * n
    prev_idx = 0

    for entry_no in range(k):
        delta, pos = _read_varint_from_bytearray(data, pos)

        # only the first entry may have delta 0 (value stored at index 0)
        if delta == 0 and entry_no > 0:
            raise ValueError("Invalid delta (must be >= 1)")
        idx = prev_idx + delta

        if value_type is None:
            zigzag_val, pos = _read_varint_from_bytearray(data, pos)
            val = _zigzag_decode(zigzag_val)
        elif value_type == 'bool':
            val = 1
        elif value_type in ('byte', 'uint8'):
            if pos >= len(data):
                raise ValueError("Unexpected end of data while reading uint8")
            val = data[pos]
            pos += 1
        elif value_type == 'int8':
            if pos + 1 > len(data):
                raise ValueError("Unexpected end of data while reading int8")
            val = struct.unpack('<b', bytes([data[pos]]))[0]
            pos += 1
        elif value_type == 'int16':
            if pos + 2 > len(data):
                raise ValueError("Unexpected end of data while reading int16")
            val = struct.unpack('<h', bytes(data[pos:pos+2]))[0]
            pos += 2
        elif value_type == 'int32':
            if pos + 4 > len(data):
                raise ValueError("Unexpected end of data while reading int32")
            val = struct.unpack('<i', bytes(data[pos:pos+4]))[0]
            pos += 4
        elif value_type == 'int64':
            if pos + 8 > len(data):
                raise ValueError("Unexpected end of data while reading int64")
            val = struct.unpack('<q', bytes(data[pos:pos+8]))[0]
            pos += 8
        else:
            raise ValueError(f"Unsupported value_type: {value_type}")

        if not (0 <= idx < n):
            raise IndexError("Decoded index out of range")
        result[idx] = val
        prev_idx = idx

    return result
17204
+
17205
+ ###################################################################################
17206
+
17207
def shift_to_smallest_integer_type(values: List[Any]) -> Dict[str, Any]:

    """
    Find one integer offset that, added to every value, makes the whole
    list fit into the narrowest of these integer types:
      - "byte"  : unsigned 8-bit  (0 .. 255)
      - "int8"  : signed 8-bit    (-128 .. 127)
      - "int16" : signed 16-bit   (-32768 .. 32767)
      - "int32" : signed 32-bit   (-2147483648 .. 2147483647)
      - "int64" : signed 64-bit   (-2**63 .. 2**63 - 1)

    Behaviour:
      - `int` values and `float` values with no fractional part (e.g. 3.0)
        are accepted; such floats are converted to int.
      - A non-list input, a bool, a non-integer float or any other element
        type makes the function give the input back untouched.
      - An empty list gives {"type": "byte", "values": [], "offset": 0}.
      - Types are tried in the order above; the first one that admits an
        offset wins.
      - Offset 0 is used whenever it is valid; otherwise the valid offset
        with the smallest absolute value (ties -> the smaller number).

    Returns:
      {"type": <type_name>, "values": <shifted_list>, "offset": <int>}
      or, when nothing could be done,
      {"type": "original", "values": <input>, "offset": 0}
    """

    def unchanged() -> Dict[str, Any]:
        return {"type": "original", "values": values, "offset": 0}

    if not isinstance(values, list):
        return unchanged()

    # Collect the values as plain ints, bailing out on anything unsuitable
    as_ints: List[int] = []

    for item in values:
        # bool is an int subclass, so it has to be rejected first
        if isinstance(item, bool):
            return unchanged()

        if isinstance(item, int):
            as_ints.append(int(item))

        elif isinstance(item, float) and item.is_integer():
            as_ints.append(int(item))

        else:
            return unchanged()

    if not as_ints:
        return {"type": "byte", "values": [], "offset": 0}

    smallest = min(as_ints)
    largest = max(as_ints)

    # (name, lowest representable, highest representable), narrowest first
    candidates = (
        ("byte", 0, 255),
        ("int8", -128, 127),
        ("int16", -32768, 32767),
        ("int32", -2147483648, 2147483647),
        ("int64", -2**63, 2**63 - 1),
    )

    for type_name, type_min, type_max in candidates:
        # valid offsets k satisfy type_min <= v + k <= type_max for all v,
        # i.e. k lies in [type_min - smallest, type_max - largest]
        k_low = type_min - smallest
        k_high = type_max - largest

        if k_low > k_high:
            continue

        if k_low <= 0 <= k_high:
            offset = 0
        elif abs(k_low) < abs(k_high):
            offset = k_low
        elif abs(k_high) < abs(k_low):
            offset = k_high
        else:
            offset = min(k_low, k_high)

        offset = int(offset)

        return {"type": type_name,
                "values": [v + offset for v in as_ints],
                "offset": offset}

    return unchanged()
17289
+
17290
+ ###################################################################################
17291
+
17292
def encode_row_zero_counts(row: List[int],
                           chunk: int = 128,
                           verbose: bool = True
                           ) -> List[int]:

    """
    Encode a binary row as counts of zeros preceding each one.

    - Row containing ones: returns one count per one, i.e. the number of
      zeros between it and the previous one (or the row start). Zeros
      after the last one are not stored; the decoder pads them back from
      the row length.
    - All-zero row: returns the row length split into chunk-sized parts
      plus the remainder (e.g. [128] for 128 zeros with chunk=128).
    - Empty row or None: returns [0].

    Any element other than 0 is treated as a one; when verbose is set a
    message is printed for every element that is neither 0 nor 1.

    Raises ValueError when an all-zero row has to be split and chunk is
    not positive.

    Configuration: for 128-column rows use CHUNK = 128 and SHIFT > 128 (e.g., 256)
    """

    if row is None:
        if verbose:
            print("row is None")

        # treat a missing row like an empty one instead of failing on len()
        return [0]

    n = len(row)

    if n == 0:
        return [0]

    zeros = 0
    zero_counts: List[int] = []

    for bit in row:
        if bit not in (0, 1):
            if verbose:
                print("row must contain only 0 or 1")

        if bit == 0:
            zeros += 1

        else:
            zero_counts.append(zeros)
            zeros = 0

    # at least one "one" was seen
    if zero_counts:
        return zero_counts

    # all-zero row: emit the length in chunk-sized pieces
    if chunk <= 0:
        # a non-positive chunk could never consume the row
        raise ValueError("chunk must be a positive integer")

    full_chunks, remainder = divmod(n, chunk)

    parts: List[int] = [chunk] * full_chunks

    if remainder > 0:
        parts.append(remainder)

    return parts
17346
+
17347
+ ###################################################################################
17348
+
17349
def decode_row_zero_counts(zero_counts: List[int],
                           n_cols: int,
                           chunk: int = 128,
                           verbose: bool = True
                           ) -> List[int]:

    """
    Decode zero_counts into a binary row of n_cols columns.

    - sum(zero_counts) == n_cols is the chunked all-zero representation
      and gives n_cols zeros.
    - Otherwise every count is expanded to that many zeros followed by a
      one, and the row is padded with zeros up to n_cols.

    Problems with the input are only reported (printed when verbose is
    set), decoding continues on a best-effort basis. In particular, when
    the counts describe more than n_cols columns a message is printed and
    the returned row is longer than n_cols.

    chunk is accepted for symmetry with the encoder and is not used here.

    Configuration: for 128-column rows use CHUNK = 128 and SHIFT > 128 (e.g., 256)
    """

    if not zero_counts:
        if verbose:
            print("zero_counts must be non-empty")

    if any((not isinstance(x, int) or x < 0) for x in zero_counts):
        if verbose:
            print("zero_counts must be nonnegative integers")

    total_zeros = sum(zero_counts)

    if total_zeros == n_cols:
        return [0] * n_cols

    # every count is followed by exactly one "one"
    ones = len(zero_counts)

    # the described zeros and ones must fit into the row
    if total_zeros + ones > n_cols:
        if verbose:
            print(f"zero_counts do not match expected row length: sum={total_zeros}, ones={ones}, n_cols={n_cols}")

    row: List[int] = []

    for count in zero_counts:
        row.extend([0] * count)
        row.append(1)

    # trailing zeros are not stored by the encoder; restore them here
    row += [0] * (n_cols - len(row))

    return row
17395
+
17396
+ ###################################################################################
17397
+
17398
def encode_matrix_marker_prefixed(matrix: List[List[int]],
                                  shift: int = 129,
                                  chunk: int = 128,
                                  verbose: bool = True
                                  ) -> Dict[str, Any]:

    """
    Run-length encode a binary matrix row by row.

    Each run of identical consecutive rows becomes one entry:
        [marker, zc0, zc1, ...]
    where zc* are the row's zero counts (from encode_row_zero_counts) and
        marker = shift + (repeat_count - 1)
    - For a single row (no repeats) repeat_count = 1 -> marker = shift
    - For k repeated rows repeat_count = k -> marker = shift + (k - 1)

    Every zero count must be < shift so a marker can never be confused with
    a zero count.

    Parameters
    ----------
    matrix  : sequence of equal-length binary rows (None or empty allowed).
    shift   : smallest marker value; must exceed every zero count.
    chunk   : forwarded to encode_row_zero_counts.
    verbose : when True, validation problems are printed.

    Returns
    -------
    Always a dict {'shape': (n_rows, n_cols), 'rows': [...]}.
    If a row has the wrong length or produces a zero count >= shift, encoding
    stops at that row; the dict then describes only the rows encoded before
    it, and 'shape' reports that smaller row count.

    Configuration: for 128-column rows use CHUNK = 128 and SHIFT > 128 (e.g., 256)
    """

    if matrix is None or len(matrix) == 0:
        return {'shape': (0, 0), 'rows': []}

    n_cols = len(matrix[0])
    encoded_rows: List[List[int]] = []

    run_zc = None     # zero counts shared by the rows of the current run
    run_len = 0       # number of rows in the current run
    rows_done = 0     # rows successfully encoded so far

    for row in matrix:

        if len(row) != n_cols:
            if verbose:
                print("All rows must have the same number of columns")
            break

        zc = encode_row_zero_counts(row, chunk=chunk)

        if any(x >= shift for x in zc):
            if verbose:
                print(f"zero_count value >= shift ({shift}). Increase SHIFT or reduce CHUNK. zc={zc}")
            break

        if run_zc is not None and zc == run_zc:
            run_len += 1

        else:
            # Row differs from the current run: emit the finished run first.
            if run_zc is not None:
                encoded_rows.append([shift + (run_len - 1)] + run_zc)

            run_zc = zc
            run_len = 1

        rows_done += 1

    # Flush the last (or, after an early stop, the pending) run.
    if run_zc is not None:
        encoded_rows.append([shift + (run_len - 1)] + run_zc)

    return {'shape': (rows_done, n_cols), 'rows': encoded_rows}
17457
+
17458
+ ###################################################################################
17459
+
17460
def decode_matrix_marker_prefixed(encoded: Dict[str, Any],
                                  shift: int = 129,
                                  chunk: int = 128,
                                  verbose: bool = True
                                  ) -> List[List[int]]:

    """
    Decode the structure produced by encode_matrix_marker_prefixed.

    Each entry must be [marker, zc0, zc1, ...] where marker >= shift.
    The repeat count is (marker - shift + 1); the decoded row is appended
    that many times (as independent list copies).

    Parameters
    ----------
    encoded : dict with 'shape' -> (n_rows, n_cols) and 'rows' -> list of entries.
    shift   : marker base used when encoding.
    chunk   : forwarded to decode_row_zero_counts.
    verbose : when True, validation problems are printed (also forwarded to
              decode_row_zero_counts).

    Returns
    -------
    The decoded matrix as a list of rows, or None when 'shape' or 'rows' is
    missing. Malformed entries (not a non-empty list, or marker < shift) are
    reported and skipped instead of crashing or corrupting the row count.
    A mismatch between the decoded row count and 'shape' is only reported.

    Configuration: for 128-column rows use CHUNK = 128 and SHIFT > 128 (e.g., 256)
    """

    if 'shape' not in encoded or 'rows' not in encoded:
        if verbose:
            print("encoded must contain 'shape' and 'rows'")
        return None

    n_rows, n_cols = encoded['shape']
    matrix: List[List[int]] = []

    for entry in encoded['rows']:

        if not isinstance(entry, list) or len(entry) == 0:
            if verbose:
                print("each encoded entry must be a non-empty list")
            # Nothing to decode; indexing entry[0] below would fail.
            continue

        marker = int(entry[0])

        if marker < shift:
            if verbose:
                print(f"marker {marker} < shift {shift}; encoded entries must start with marker")
            # A marker below shift would give a repeat count < 1.
            continue

        repeat_count = (marker - shift) + 1

        zero_counts = entry[1:]

        if any((not isinstance(x, int) or x < 0) for x in zero_counts):
            if verbose:
                print("zero_counts must be nonnegative integers")

        if any(x >= shift for x in zero_counts):
            if verbose:
                print("zero_counts contain value >= shift; ambiguous marker")

        row = decode_row_zero_counts(zero_counts,
                                     n_cols,
                                     chunk=chunk,
                                     verbose=verbose
                                     )

        # Each repeated row gets its own list so later edits do not alias.
        matrix.extend(list(row) for _ in range(repeat_count))

    if len(matrix) != n_rows:
        if verbose:
            print(f"Decoded row count {len(matrix)} does not match shape {n_rows}")

    return matrix
17522
+
17523
+ ###################################################################################
17524
+
17525
def escore_notes_to_rle_tokens(escore_notes,
                               shift=129,
                               chunk=128,
                               verbose=False
                               ):

    """
    Convert escore notes into a flat list of marker-prefixed RLE tokens.

    The notes are turned into a binary matrix with
    escore_notes_to_binary_matrix, the matrix is encoded with
    encode_matrix_marker_prefixed, and the encoded entries are passed
    through flatten.

    Parameters
    ----------
    escore_notes : input accepted by escore_notes_to_binary_matrix.
    shift        : marker base forwarded to the matrix encoder.
    chunk        : chunk size forwarded to the matrix encoder.
    verbose      : forwarded to the matrix encoder.

    Returns
    -------
    Whatever flatten returns for the encoded entries
    (presumably one flat token list - confirm against flatten).
    """

    bmatrix = escore_notes_to_binary_matrix(escore_notes)

    # shift/chunk must reach the encoder, otherwise non-default values
    # passed by the caller are silently ignored.
    enc = encode_matrix_marker_prefixed(bmatrix,
                                        shift=shift,
                                        chunk=chunk,
                                        verbose=verbose
                                        )

    # The encoder may hand back a bare list of entries when it aborts on
    # invalid input; accept both shapes instead of failing on enc['rows'].
    rows = enc['rows'] if isinstance(enc, dict) else enc

    return flatten(rows)
17538
+
17539
+ ###################################################################################
17540
+
17541
def rle_tokens_to_escore_notes(rle_tokens_list,
                               shift=129,
                               chunk=128,
                               return_bmatrix=False,
                               return_enc_dic=False,
                               verbose=False
                               ):

    """
    Rebuild escore notes from a flat list of marker-prefixed RLE tokens.

    Tokens >= shift are markers that start a new entry; tokens < shift are
    zero counts appended to the current entry. The entries are wrapped in
    an {'shape', 'rows'} dict, decoded with decode_matrix_marker_prefixed,
    and converted with binary_matrix_to_original_escore_notes.

    Parameters
    ----------
    rle_tokens_list : flat sequence of integer tokens (may be empty).
    shift           : marker base used when the tokens were encoded.
    chunk           : forwarded to the decoder and also used as the column
                      count stored in 'shape'.
    return_bmatrix  : when True, return the decoded binary matrix.
    return_enc_dic  : when True, return the intermediate dict without
                      decoding (takes precedence over return_bmatrix).
    verbose         : forwarded to the matrix decoder.

    Returns
    -------
    The encoded dict, the binary matrix, or the result of
    binary_matrix_to_original_escore_notes, depending on the flags.
    """

    rows = []
    row = []
    row_count = 0

    # Guard the peek at the first token so an empty input does not raise.
    # NOTE(review): a missing leading marker is replaced by shift+1, i.e. a
    # repeat count of 2 rather than 1 - kept as in the original; confirm intent.
    if len(rle_tokens_list) > 0 and rle_tokens_list[0] < shift:
        row.append(shift+1)

    for t in rle_tokens_list:
        if t >= shift:
            # A marker closes the entry collected so far and opens a new one.
            if row:
                rows.append(row)
                row_count += (row[0]-shift+1)

            row = [t]

        else:
            row.append(t)

    if row:
        rows.append(row)
        row_count += (row[0]-shift+1)

    enc_dic = {'shape': (row_count, chunk),
               'rows': rows
               }

    if return_enc_dic:
        return enc_dic

    bmatrix = decode_matrix_marker_prefixed(enc_dic,
                                            shift=shift,
                                            chunk=chunk,
                                            verbose=verbose
                                            )

    if return_bmatrix:
        return bmatrix

    return binary_matrix_to_original_escore_notes(bmatrix)
17589
+
17590
+ ###################################################################################
17591
 
17592
  print('Module loaded!')
17593
  print('=' * 70)