# gzip-text-classifier / gzip_classifier.py
# knoxel's picture
# Upload gzip_classifier.py
# fa1acc2 verified
"""
Implementation of "Less is More: Parameter-Free Text Classification with Gzip"
Paper: https://arxiv.org/abs/2212.09410 (Jiang et al., 2022)
Zero parameters. Zero training. CPU only. 15 lines of core logic.
Usage:
python gzip_classifier.py
"""
import gzip
import numpy as np
import time
import json
from collections import Counter
from datasets import load_dataset
from multiprocessing import Pool, cpu_count
# ============================================================
# Configuration
# ============================================================
# Number of training examples drawn for each class label.
TRAIN_SAMPLES_PER_CLASS = 500
# Number of test examples drawn (without replacement) from the test split.
TEST_SAMPLES = 200
# Neighbour counts evaluated in the accuracy sweep.
K_VALUES = [1, 2, 3, 5, 7]
# Seed given to NumPy's global RNG so the sampled subsets are repeatable.
RANDOM_SEED = 42
# Worker-process count: every core but one, and never fewer than one.
NUM_WORKERS = max(1, cpu_count() - 1)
# Display names for the four integer class labels.
# NOTE(review): presumably matches the AG News label ids -- confirm against the dataset card.
LABEL_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
def ncd(x1, x2, Cx1=None, Cx2=None):
    """Return the gzip-based Normalized Compression Distance of two strings.

    NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y)), where C(.) is
    the gzip-compressed size in bytes and xy is the two texts joined by a
    single space.

    Args:
        x1: First text.
        x2: Second text.
        Cx1: Precomputed compressed size of ``x1``; computed here when None.
        Cx2: Precomputed compressed size of ``x2``; computed here when None.

    Returns:
        The distance as a float.
    """

    def _compressed_size(text):
        return len(gzip.compress(text.encode()))

    size_first = _compressed_size(x1) if Cx1 is None else Cx1
    size_second = _compressed_size(x2) if Cx2 is None else Cx2
    size_joint = _compressed_size(" ".join((x1, x2)))

    smaller, larger = sorted((size_first, size_second))
    return (size_joint - smaller) / larger
def gzip_classify(test_text, train_texts, train_labels, train_compressed, k=7):
    """Classify ``test_text`` by majority vote of its k nearest NCD neighbours.

    Args:
        test_text: The string to classify.
        train_texts: Training strings.
        train_labels: Labels, indexable in parallel with ``train_texts``.
        train_compressed: Gzip-compressed byte size of each training string,
            in the same order as ``train_texts``.
        k: Number of nearest neighbours that vote. Values larger than the
            training set simply use every training example.

    Returns:
        The most frequent label among the k nearest training examples. When
        several labels tie, the one whose first vote comes from the closest
        neighbour wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or there is no training data.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    Cx1 = len(gzip.compress(test_text.encode()))

    # NCD(x, y) = (C(xy) - min(C(x), C(y))) / max(C(x), C(y)); the pair is
    # joined with a single space before compressing.
    distances = [
        (len(gzip.compress(" ".join([test_text, x2]).encode())) - min(Cx1, Cx2))
        / max(Cx1, Cx2)
        for x2, Cx2 in zip(train_texts, train_compressed)
    ]
    if not distances:
        raise ValueError("cannot classify without any training examples")

    # NCD values are ratios of small integers, so exact ties are common. A
    # stable sort keeps tied neighbours in training order, which makes the
    # prediction reproducible across NumPy builds and platforms.
    nearest = np.argsort(distances, kind="stable")[:k]
    top_k_labels = [train_labels[i] for i in nearest]
    return Counter(top_k_labels).most_common(1)[0][0]
def compute_distances_worker(args):
    """Compute the NCD from one text to every training text.

    Takes a single tuple so it can be handed directly to a pool's map-style
    methods.

    Args:
        args: ``(test_idx, x1, Cx1, train_texts_local, train_compressed_local)``
            where ``test_idx`` is an identifier echoed back to the caller,
            ``x1`` is the text being compared, ``Cx1`` is its gzip-compressed
            size, and the last two items are the training texts and their
            compressed sizes in matching order.

    Returns:
        ``(test_idx, distances)`` with ``distances`` a float array holding
        one NCD value per training text.
    """
    test_idx, x1, Cx1, train_texts_local, train_compressed_local = args

    row = []
    for position, candidate in enumerate(train_texts_local):
        joint_size = len(gzip.compress((x1 + " " + candidate).encode()))
        candidate_size = train_compressed_local[position]
        smaller, larger = sorted((Cx1, candidate_size))
        row.append((joint_size - smaller) / larger)

    return test_idx, np.array(row, dtype=float)
def main():
    """Run the AG News experiment end to end and print accuracy for each k.

    Steps: load the dataset, draw a class-balanced training subset and a
    random test subset, pre-compress every text, build the test-by-train NCD
    matrix in parallel, then report kNN accuracy for every value in K_VALUES.
    """
    print("=" * 60)
    print("Gzip + kNN Text Classification (Jiang et al., 2022)")
    print("Paper: https://arxiv.org/abs/2212.09410")
    print("=" * 60)

    # Load dataset
    print("Loading AG News...")
    dataset = load_dataset("fancyzhx/ag_news")
    train_split, test_split = dataset["train"], dataset["test"]
    np.random.seed(RANDOM_SEED)

    # Read each column exactly once. Indexing dataset[split][column] inside a
    # comprehension re-reads the whole column for every sampled index.
    train_label_column = np.asarray(list(train_split["label"]))
    train_text_column = list(train_split["text"])
    test_label_column = np.asarray(list(test_split["label"]))
    test_text_column = list(test_split["text"])

    # Stratified sampling: the same number of training examples per class.
    # The RNG calls are made in the same order as before, so a given seed
    # still selects the same examples.
    train_indices = []
    for label in sorted(LABEL_NAMES):
        label_indices = np.flatnonzero(train_label_column == label)
        sampled = np.random.choice(label_indices, TRAIN_SAMPLES_PER_CLASS, replace=False)
        train_indices.extend(sampled)
    np.random.shuffle(train_indices)

    train_texts = [train_text_column[i] for i in train_indices]
    train_labels = train_label_column[train_indices]

    test_indices = np.random.choice(len(test_split), TEST_SAMPLES, replace=False)
    test_texts = [test_text_column[i] for i in test_indices]
    test_labels = test_label_column[test_indices]
    print(f"Train: {len(train_texts)}, Test: {len(test_texts)}")

    # Pre-compress
    train_compressed = [len(gzip.compress(t.encode())) for t in train_texts]
    test_compressed = [len(gzip.compress(t.encode())) for t in test_texts]

    # Compute distance matrix: one task per test text, spread over the pool.
    print(f"Computing {len(test_texts)}x{len(train_texts)} NCD matrix...")
    t0 = time.time()
    distance_matrix = np.zeros((len(test_texts), len(train_texts)))
    tasks = (
        (i, text, size, train_texts, train_compressed)
        for i, (text, size) in enumerate(zip(test_texts, test_compressed))
    )
    with Pool(NUM_WORKERS) as pool:
        # Rows may finish in any order; each result carries its own row index.
        finished_rows = pool.imap_unordered(compute_distances_worker, tasks)
        for done, (test_idx, row) in enumerate(finished_rows, start=1):
            distance_matrix[test_idx] = row
            if done % 20 == 0:
                print(f" {done}/{len(test_texts)}")
    print(f"Done in {time.time() - t0:.1f}s")

    # Sweep k values. The neighbour ordering does not depend on k, so sort
    # once; a stable sort keeps tied distances in training order so results
    # are reproducible.
    print("\nResults:")
    neighbor_order = np.argsort(distance_matrix, axis=1, kind="stable")
    for k in K_VALUES:
        preds = np.array([
            Counter(train_labels[row[:k]].tolist()).most_common(1)[0][0]
            for row in neighbor_order
        ])
        acc = np.mean(preds == test_labels)
        print(f" k={k}: accuracy={acc:.4f}")


if __name__ == "__main__":
    main()