knoxel committed on
Commit
fa1acc2
·
verified ·
1 Parent(s): a7279bf

Upload gzip_classifier.py

Browse files
Files changed (1) hide show
  1. gzip_classifier.py +124 -0
gzip_classifier.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of "Less is More: Parameter-Free Text Classification with Gzip"
3
+ Paper: https://arxiv.org/abs/2212.09410 (Jiang et al., 2022)
4
+
5
+ Zero parameters. Zero training. CPU only. 15 lines of core logic.
6
+
7
+ Usage:
8
+ python gzip_classifier.py
9
+ """
10
+
11
+ import gzip
12
+ import numpy as np
13
+ import time
14
+ import json
15
+ from collections import Counter
16
+ from datasets import load_dataset
17
+ from multiprocessing import Pool, cpu_count
18
+
19
+
20
+ # ============================================================
21
+ # Configuration
22
+ # ============================================================
23
+ TRAIN_SAMPLES_PER_CLASS = 500
24
+ TEST_SAMPLES = 200
25
+ K_VALUES = [1, 2, 3, 5, 7]
26
+ RANDOM_SEED = 42
27
+ NUM_WORKERS = max(1, cpu_count() - 1)
28
+ LABEL_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
29
+
30
+
31
def ncd(x1, x2, Cx1=None, Cx2=None):
    """Return the Normalized Compression Distance between two strings.

    NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y))

    Here C(s) is the byte length of the gzip-compressed encoding of ``s`` and
    ``xy`` is the two strings joined by a single space.

    Args:
        x1: First text.
        x2: Second text.
        Cx1: Precomputed C(x1); computed here when ``None``.
        Cx2: Precomputed C(x2); computed here when ``None``.

    Returns:
        The distance between ``x1`` and ``x2``.
    """

    def _gz_size(text):
        # Compressed size in bytes of the encoded text.
        return len(gzip.compress(text.encode()))

    # Only compress the individual texts when the caller has not cached them.
    size_first = _gz_size(x1) if Cx1 is None else Cx1
    size_second = _gz_size(x2) if Cx2 is None else Cx2
    size_joined = _gz_size(" ".join((x1, x2)))

    smaller = min(size_first, size_second)
    larger = max(size_first, size_second)
    return (size_joined - smaller) / larger
42
+
43
+
44
def gzip_classify(test_text, train_texts, train_labels, train_compressed, k=7):
    """Classify text using gzip NCD + kNN.

    Args:
        test_text: The string to classify.
        train_texts: Training strings.
        train_labels: Labels indexable by position, aligned with
            ``train_texts``.
        train_compressed: Precomputed gzip-compressed byte length of each
            training string, aligned with ``train_texts``.
        k: Number of nearest neighbours that vote. A value larger than the
            training set simply uses every training sample.

    Returns:
        The most common label among the ``k`` nearest training samples. When
        the vote is tied, the tied label whose nearest member is closest wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or there is no training data.
    """
    # A non-positive k would otherwise either crash with an opaque IndexError
    # (k == 0) or silently drop the farthest neighbours (negative slice end).
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    Cx1 = len(gzip.compress(test_text.encode()))
    distances = [
        ncd(test_text, x2, Cx1, Cx2)
        for x2, Cx2 in zip(train_texts, train_compressed)
    ]
    if not distances:
        raise ValueError("gzip_classify requires at least one training sample")

    # NCD values are ratios of small integers, so exact ties are common. A
    # stable sort keeps tied samples in training order instead of leaving the
    # order to the sort implementation.
    sorted_idx = np.argsort(distances, kind="stable")
    top_k_labels = [train_labels[i] for i in sorted_idx[:k]]
    # Counter.most_common keeps first-seen order among equal counts, and the
    # labels are listed nearest-first.
    return Counter(top_k_labels).most_common(1)[0][0]
54
+
55
+
56
def compute_distances_worker(args):
    """Compute the NCD from one test text to every training text.

    Args:
        args: A 5-tuple ``(test_idx, x1, Cx1, train_texts, train_compressed)``
            holding the test row index, the test text, its gzip-compressed
            byte length, the training texts, and their gzip-compressed byte
            lengths. Everything is packed into one argument so the function
            can be used with single-argument map-style APIs.

    Returns:
        A tuple ``(test_idx, distances)`` where ``distances`` is a 1-D float
        array with one entry per training text.
    """
    row_id, query, query_size, candidates, candidate_sizes = args
    row = np.zeros(len(candidates))
    for pos, candidate in enumerate(candidates):
        joined = " ".join((query, candidate))
        joined_size = len(gzip.compress(joined.encode()))
        candidate_size = candidate_sizes[pos]
        smaller = min(query_size, candidate_size)
        larger = max(query_size, candidate_size)
        row[pos] = (joined_size - smaller) / larger
    return row_id, row
66
+
67
+
68
if __name__ == "__main__":
    # End-to-end demo: sample AG News, build the test-by-train NCD matrix,
    # then report kNN accuracy for every k in K_VALUES.
    print("=" * 60)
    print("Gzip + kNN Text Classification (Jiang et al., 2022)")
    print("Paper: https://arxiv.org/abs/2212.09410")
    print("=" * 60)

    # Load dataset
    print("Loading AG News...")
    dataset = load_dataset("fancyzhx/ag_news")
    np.random.seed(RANDOM_SEED)

    # Read every column exactly once. Indexing dataset[split][column] inside
    # a comprehension re-fetches the whole column for each element, which
    # dominates start-up time on a split of this size.
    train_text_col = list(dataset["train"]["text"])
    train_label_col = list(dataset["train"]["label"])
    test_text_col = list(dataset["test"]["text"])
    test_label_col = list(dataset["test"]["label"])

    # Stratified sampling: the same number of training examples per class.
    train_indices = []
    for label in LABEL_NAMES:
        label_indices = [
            i for i, value in enumerate(train_label_col) if value == label
        ]
        sampled = np.random.choice(
            label_indices, TRAIN_SAMPLES_PER_CLASS, replace=False
        )
        train_indices.extend(sampled)
    np.random.shuffle(train_indices)

    train_texts = [train_text_col[i] for i in train_indices]
    train_labels = np.array([train_label_col[i] for i in train_indices])

    test_indices = np.random.choice(len(test_text_col), TEST_SAMPLES, replace=False)
    test_texts = [test_text_col[i] for i in test_indices]
    test_labels = np.array([test_label_col[i] for i in test_indices])

    print(f"Train: {len(train_texts)}, Test: {len(test_texts)}")

    # Pre-compress: C(x) of every individual text is reused for each pair.
    train_compressed = [len(gzip.compress(t.encode())) for t in train_texts]
    test_compressed = [len(gzip.compress(t.encode())) for t in test_texts]

    # Compute distance matrix
    print(f"Computing {len(test_texts)}x{len(train_texts)} NCD matrix...")
    t0 = time.time()

    distance_matrix = np.zeros((len(test_texts), len(train_texts)))
    for i, (text, size) in enumerate(zip(test_texts, test_compressed)):
        # Reuse the row worker rather than repeating the NCD formula inline.
        _, row = compute_distances_worker(
            (i, text, size, train_texts, train_compressed)
        )
        distance_matrix[i] = row
        if (i + 1) % 20 == 0:
            print(f" {i+1}/{len(test_texts)}")

    print(f"Done in {time.time() - t0:.1f}s")

    # Sweep k values. The neighbour ranking does not depend on k, so sort
    # once. A stable sort keeps tied distances in training order instead of
    # leaving the order to the sort implementation.
    print("\nResults:")
    neighbor_order = np.argsort(distance_matrix, axis=1, kind="stable")
    for k in K_VALUES:
        preds = np.array([
            Counter(train_labels[nearest].tolist()).most_common(1)[0][0]
            for nearest in neighbor_order[:, :k]
        ])
        acc = np.mean(preds == test_labels)
        print(f" k={k}: accuracy={acc:.4f}")