| |
| """ |
| Million-scale comprehensive test suite for deeplatent-nlp. |
| |
| Tests: |
| 1. Roundtrip accuracy on 1M+ samples from /root/.cache/deeplatent/base_data/ |
| 2. All 12 edge case categories from test_edge_cases.py |
| 3. Performance metrics (throughput, memory) |
| 4. PyPI vs Local tokenizer comparison |
| |
| Usage: |
| python test_comprehensive_million.py [--samples 1000000] [--report] |
| |
| # Quick test with 10k samples |
| python test_comprehensive_million.py --samples 10000 |
| |
| # Full million-scale test |
| python test_comprehensive_million.py --samples 1000000 --report |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| import tracemalloc |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import pyarrow.parquet as pq |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from deeplatent import SARFTokenizer, version, RUST_AVAILABLE |
| from deeplatent.config import ( |
| NormalizationConfig, |
| UnicodeNormalizationForm, |
| WhitespaceNormalization, |
| ControlCharStrategy, |
| ZeroWidthStrategy, |
| ) |
| from deeplatent.utils import ( |
| |
| is_arabic, |
| is_arabic_diacritic, |
| is_pua, |
| is_zero_width, |
| is_unicode_whitespace, |
| is_control_char, |
| is_emoji, |
| is_emoji_sequence, |
| is_skin_tone_modifier, |
| is_regional_indicator, |
| |
| normalize_nfc, |
| normalize_nfkc, |
| normalize_apostrophes, |
| normalize_dashes, |
| normalize_whitespace, |
| normalize_unicode_whitespace, |
| remove_zero_width, |
| remove_zero_width_all, |
| remove_zero_width_preserve_zwj, |
| remove_control_chars, |
| strip_diacritics, |
| normalize_alef, |
| remove_tatweel, |
| full_normalize_extended, |
| |
| contains_url, |
| contains_email, |
| contains_path, |
| extract_urls, |
| extract_emails, |
| is_valid_url, |
| is_valid_email, |
| |
| grapheme_count, |
| |
| validate_input, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# Default locations for test data and tokenizer artifacts.
DATA_DIR = "/root/.cache/deeplatent/base_data/"  # parquet shards of the base corpus
HF_REPO = "almaghrabima/SARFTokenizer"  # HuggingFace Hub repo id used as the default --tokenizer
# Expected local cache path for the HF tokenizer.
# NOTE(review): download_tokenizer_from_hf() writes to
# "<cache>/almaghrabima_SARFTokenizer" (repo id with "/" -> "_"), which does
# not match this path — verify this fallback can ever exist.
HF_TOKENIZER_PATH = os.path.expanduser("~/.cache/deeplatent/tokenizers/SARFTokenizer")
# Pre-existing local tokenizer build used as a last-resort fallback.
LOCAL_TOKENIZER = "/root/.cache/DeepLatent/SARFTokenizer/SARF-65k-v2-fixed/"
|
|
|
|
def download_tokenizer_from_hf(repo_id: str, cache_dir: Optional[str] = None) -> str:
    """
    Download tokenizer files from HuggingFace Hub.

    Args:
        repo_id: HuggingFace repo ID (e.g., "almaghrabima/SARFTokenizer")
        cache_dir: Optional cache directory (defaults to
            ~/.cache/deeplatent/tokenizers)

    Returns:
        Local path to downloaded tokenizer directory

    Raises:
        Exception: re-raised from ``snapshot_download`` on any download failure.
    """
    # Imported lazily so the module stays importable without huggingface_hub.
    # (The previously imported hf_hub_download was never used.)
    from huggingface_hub import snapshot_download

    if cache_dir is None:
        cache_dir = os.path.expanduser("~/.cache/deeplatent/tokenizers")

    os.makedirs(cache_dir, exist_ok=True)

    # Flatten "org/name" into a single directory name inside the cache.
    local_dir = os.path.join(cache_dir, repo_id.replace("/", "_"))

    try:
        # snapshot_download returns the resolved path (may differ from local_dir).
        local_dir = snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="model",
        )
        print(f" Downloaded tokenizer to: {local_dir}")
        return local_dir
    except Exception as e:
        print(f" Warning: Could not download from HF Hub: {e}")
        raise
|
|
|
|
| |
| |
| |
|
|
def load_base_data(data_dir: str, num_samples: int = 1000000) -> Tuple[List[str], List[str], List[str]]:
    """
    Load samples from base_data parquet shards, bucketed by Arabic-char ratio.

    Args:
        data_dir: Directory containing ``shard_*.parquet`` files.
        num_samples: Total target count, split evenly across the three buckets.

    Returns:
        Tuple of (arabic_samples, english_samples, mixed_samples)

    Raises:
        FileNotFoundError: If no shard_*.parquet files exist in data_dir.
    """
    import re
    AR_DETECT = re.compile(r'[\u0600-\u06FF]')

    parquet_files = sorted(Path(data_dir).glob("shard_*.parquet"))
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found in {data_dir}")

    print(f"Found {len(parquet_files)} parquet shards")

    arabic_samples: List[str] = []
    english_samples: List[str] = []
    mixed_samples: List[str] = []

    target_per_category = num_samples // 3

    def buckets_full() -> bool:
        # Early-exit check shared by the shard loop and the row loop.
        return (len(arabic_samples) >= target_per_category and
                len(english_samples) >= target_per_category and
                len(mixed_samples) >= target_per_category)

    for pq_file in parquet_files:
        if buckets_full():
            break

        # Only request columns that actually exist: asking read_table for a
        # missing "language" column raises, so the old post-read
        # `"language" in table.column_names` guard could never trigger.
        schema_names = pq.read_schema(pq_file).names
        columns = ["text", "language"] if "language" in schema_names else ["text"]
        table = pq.read_table(pq_file, columns=columns)
        texts = table.column("text").to_pylist()

        for text in texts:
            if buckets_full():
                break

            if not text or not isinstance(text, str):
                continue

            # Classify by the fraction of Arabic-block characters.
            ar_chars = len(AR_DETECT.findall(text))
            total_chars = len(text)
            ar_ratio = ar_chars / total_chars if total_chars > 0 else 0

            if ar_ratio > 0.5 and len(arabic_samples) < target_per_category:
                arabic_samples.append(text)
            elif ar_ratio < 0.1 and len(english_samples) < target_per_category:
                english_samples.append(text)
            elif 0.1 <= ar_ratio <= 0.5 and len(mixed_samples) < target_per_category:
                mixed_samples.append(text)

        print(f" {pq_file.name}: AR={len(arabic_samples):,}, EN={len(english_samples):,}, Mixed={len(mixed_samples):,}")

    total_loaded = len(arabic_samples) + len(english_samples) + len(mixed_samples)
    print(f"\nTotal loaded: {total_loaded:,} samples")
    print(f" Arabic: {len(arabic_samples):,}")
    print(f" English: {len(english_samples):,}")
    print(f" Mixed: {len(mixed_samples):,}")

    return arabic_samples, english_samples, mixed_samples
|
|
|
|
| |
| |
| |
|
|
def test_roundtrip_batch(
    tokenizer: SARFTokenizer,
    samples: List[str],
    category: str,
    max_failures: int = 100,
) -> Dict:
    """
    Encode/decode every sample and check the text survives the roundtrip.

    A sample passes when the decoded text equals either the original text or
    the tokenizer-normalized text (when the tokenizer exposes ``normalize``).

    Args:
        tokenizer: Object exposing encode()/decode() (and optionally normalize()).
        samples: Texts to test.
        category: Label stored in the result dict.
        max_failures: Cap on recorded failure examples (counts stay exact).

    Returns:
        Dict with success count, failures, accuracy, timing
    """
    success = 0
    failures = []
    total_encode_time = 0.0
    total_decode_time = 0.0

    for i, text in enumerate(samples):
        try:
            t0 = time.perf_counter()
            ids = tokenizer.encode(text)
            total_encode_time += time.perf_counter() - t0

            t0 = time.perf_counter()
            decoded = tokenizer.decode(ids)
            total_decode_time += time.perf_counter() - t0

            # Tokenizers that normalize on encode legitimately return a
            # normalized form, so accept either the raw or normalized text.
            # (The previous `if a == b if c else True` parsed as a conditional
            # expression and unconditionally counted every sample as a success
            # when the tokenizer had no `normalize` method.)
            expected = tokenizer.normalize(text) if hasattr(tokenizer, 'normalize') else text
            if decoded == expected or decoded == text:
                success += 1
            elif len(failures) < max_failures:
                failures.append({
                    "index": i,
                    "original": text[:100],
                    "decoded": decoded[:100],
                })
        except Exception as e:
            if len(failures) < max_failures:
                failures.append({
                    "index": i,
                    "original": text[:100] if text else "",
                    "error": str(e),
                })

    total = len(samples)
    accuracy = success / total if total > 0 else 0

    return {
        "category": category,
        "total": total,
        "success": success,
        "failed": total - success,
        "accuracy": accuracy,
        "accuracy_pct": f"{accuracy * 100:.2f}%",
        "encode_time": total_encode_time,
        "decode_time": total_decode_time,
        "failures": failures,
    }
|
|
|
|
def run_roundtrip_tests(
    tokenizer: SARFTokenizer,
    arabic_samples: List[str],
    english_samples: List[str],
    mixed_samples: List[str],
) -> Dict:
    """Run roundtrip tests for the Arabic, English and Mixed sample sets.

    Empty categories are skipped; an aggregate "TOTAL" entry is always added.
    """
    results: Dict = {}

    for label, batch in (
        ("Arabic", arabic_samples),
        ("English", english_samples),
        ("Mixed", mixed_samples),
    ):
        if not batch:
            continue
        print(f" Testing {label} ({len(batch):,} samples)...", end=" ", flush=True)
        outcome = test_roundtrip_batch(tokenizer, batch, label)
        results[label] = outcome
        print(f"Accuracy: {outcome['accuracy_pct']}")

    # Fold the per-category counts into one aggregate row.
    grand_success = sum(r["success"] for r in results.values())
    grand_total = sum(r["total"] for r in results.values())
    grand_failed = sum(r["failed"] for r in results.values())
    overall = grand_success / grand_total if grand_total > 0 else 0

    results["TOTAL"] = {
        "category": "TOTAL",
        "total": grand_total,
        "success": grand_success,
        "failed": grand_failed,
        "accuracy": overall,
        "accuracy_pct": f"{overall * 100:.2f}%",
    }

    return results
|
|
|
|
| |
| |
| |
|
|
# Edge-case fixtures: each category maps to a list of
# (input, expected_output, description) tuples.
# expected_output is None when the check is "is detected / does not crash"
# rather than an exact normalized result (see run_edge_case_tests).
EDGE_CASE_TESTS = {
    "Unicode Normalization": [
        ("cafe\u0301", "café", "NFC: combining acute"),
        ("n\u0303", "ñ", "NFC: combining tilde"),
        ("e\u0308", "ë", "NFC: combining diaeresis"),
        ("\uFB01", "fi", "NFKC: fi ligature"),
        ("\uFF21", "A", "NFKC: fullwidth A"),
        ("ك\u0651", None, "Arabic shadda combining"),
    ],
    "Zero-Width Characters": [
        ("a\u200Bb", "ab", "ZWSP removal"),
        ("a\u200C\u200Db", None, "ZWNJ + ZWJ"),
        ("a\u200Eb", None, "LRM"),
        ("a\u200Fb", None, "RLM"),
        ("a\u2060b", None, "Word Joiner"),
        ("a\uFEFFb", None, "BOM"),
    ],
    "Unicode Whitespace": [
        ("a\u00A0b", "a b", "NBSP"),
        ("a\u2003b", "a b", "Em Space"),
        ("a\u2009b", "a b", "Thin Space"),
        ("a\u202Fb", None, "Narrow NBSP"),
        ("a\u3000b", None, "Ideographic Space"),
        ("a\r\nb", None, "CRLF"),
    ],
    "Grapheme Clusters": [
        ("👨‍👩‍👧‍👦", None, "Family emoji ZWJ"),
        ("🇸🇦", None, "Flag emoji"),
        ("👋🏽", None, "Emoji with skin tone"),
        ("✊🏻", None, "Fist with light skin"),
        ("👨‍💻", None, "Man technologist"),
        ("🏳️‍🌈", None, "Rainbow flag"),
    ],
    "Apostrophes": [
        ("don\u2019t", "don't", "Right single quote"),
        ("don\u2018t", "don't", "Left single quote"),
        ("James\u2019", "James'", "Possessive"),
        ("l\u2019homme", "l'homme", "French contraction"),
    ],
    "Dashes": [
        ("10\u201312", "10-12", "En dash range"),
        ("\u22125", "-5", "Minus sign"),
        ("state\u2014of\u2014the\u2014art", None, "Em dashes"),
        ("COVID\u201019", None, "Hyphen"),
    ],
    "Decimal Separators": [
        ("3.14159", None, "Standard decimal"),
        ("٢٣\u066B٥", None, "Arabic decimal separator"),
        ("٠١٢٣٤٥٦٧٨٩", None, "Arabic-Indic digits"),
    ],
    "URLs/Emails": [
        ("https://example.com", None, "Simple URL"),
        ("https://example.com/path?x=1&y=2#top", None, "Complex URL"),
        ("user@example.com", None, "Simple email"),
        ("first.last+tag@domain.co.uk", None, "Complex email"),
    ],
    "File Paths": [
        ("C:\\Windows\\System32", None, "Windows path"),
        ("/home/user/file.txt", None, "Unix path"),
        ("\\\\server\\share\\file.txt", None, "UNC path"),
    ],
    "Code Identifiers": [
        ("snake_case_variable", None, "snake_case"),
        ("camelCaseVariable", None, "camelCase"),
        ("HTTPServerError500", None, "PascalCase"),
        ("kebab-case-id", None, "kebab-case"),
    ],
    "Mixed Scripts/RTL": [
        ("Hello مرحبا World", None, "Arabic + English"),
        ("Riyadh الرياض", None, "City name mixed"),
        ("بِسْمِ", None, "Arabic with diacritics"),
        ("مـــرحـــبا", None, "Arabic with tatweel"),
        ("أحمد", None, "Alef variants"),
        ("١٢٣", None, "Arabic numerals"),
    ],
    "Robustness": [
        ("", None, "Empty string"),
        (" ", None, "Whitespace only"),
        ("\t\n\r", None, "Control whitespace"),
        ("a\x00b", "ab", "NULL byte"),
        ("a\x1Fb", "ab", "Control char"),
        ("a" * 10000, None, "Large input"),
    ],
}
|
|
|
|
def run_edge_case_tests() -> Dict:
    """Run all 12 categories of edge case tests.

    Each category in EDGE_CASE_TESTS is dispatched to the matching
    deeplatent.utils helper(s). A test passes when the helper produces the
    expected output (if one is given) or simply does not raise.

    Returns:
        Per-category dicts with tests/passed/failed counts and failure
        messages, plus an aggregate "TOTAL" entry.
    """
    results: Dict = {}
    total_tests = 0
    total_passed = 0

    for category, tests in EDGE_CASE_TESTS.items():
        passed = 0
        failed = []

        for test_input, expected_output, description in tests:
            total_tests += 1
            try:
                if category == "Unicode Normalization":
                    # Only verify cases that specify an exact normalized form.
                    if expected_output and expected_output != test_input:
                        if "NFKC" in description:
                            result = normalize_nfkc(test_input)
                        else:
                            result = normalize_nfc(test_input)
                        if result == expected_output:
                            passed += 1
                        else:
                            failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Zero-Width Characters":
                    # Detection must flag every zero-width char. Explicit raise
                    # (not a bare assert) so the check survives `python -O`.
                    for char in test_input:
                        if char in "\u200B\u200C\u200D\u200E\u200F\u2060\uFEFF":
                            if not is_zero_width(char):
                                raise AssertionError(f"is_zero_width missed {char!r}")
                    result = remove_zero_width_all(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Unicode Whitespace":
                    result = normalize_unicode_whitespace(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Grapheme Clusters":
                    is_seq = is_emoji_sequence(test_input)
                    grapheme_count(test_input)  # smoke-check: must not raise
                    if not is_seq:
                        failed.append(f"{description}: not detected as emoji sequence")
                    else:
                        passed += 1

                elif category == "Apostrophes":
                    result = normalize_apostrophes(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Dashes":
                    result = normalize_dashes(test_input)
                    if expected_output and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

                elif category == "Decimal Separators":
                    # No dedicated helper yet; counted as pass when reached.
                    passed += 1

                elif category == "URLs/Emails":
                    if "URL" in description:
                        if not contains_url(test_input):
                            failed.append(f"{description}: URL not detected")
                        else:
                            passed += 1
                    else:
                        if not contains_email(test_input):
                            failed.append(f"{description}: Email not detected")
                        else:
                            passed += 1

                elif category == "File Paths":
                    if not contains_path(test_input):
                        failed.append(f"{description}: Path not detected")
                    else:
                        passed += 1

                elif category == "Code Identifiers":
                    # No dedicated helper yet; counted as pass when reached.
                    passed += 1

                elif category == "Mixed Scripts/RTL":
                    has_arabic = any(is_arabic(c) for c in test_input)
                    if "Arabic" in description and not has_arabic:
                        failed.append(f"{description}: Arabic not detected")
                    else:
                        passed += 1

                elif category == "Robustness":
                    # Must not raise; verify the cleaned form when one is
                    # given. (Previously the expected outputs for the NULL /
                    # control cases were defined but never compared, so this
                    # category always passed.)
                    if "NULL" in description or "Control" in description:
                        result = remove_control_chars(test_input)
                    else:
                        result = normalize_whitespace(test_input)
                    if expected_output is not None and result != expected_output:
                        failed.append(f"{description}: got '{result}', expected '{expected_output}'")
                    else:
                        passed += 1

            except Exception as e:
                failed.append(f"{description}: Exception {e}")

        total_passed += passed
        results[category] = {
            "tests": len(tests),
            "passed": passed,
            "failed": len(tests) - passed,
            "failures": failed,
        }

    results["TOTAL"] = {
        "tests": total_tests,
        "passed": total_passed,
        "failed": total_tests - total_passed,
    }

    return results
|
|
|
|
| |
| |
| |
|
|
def measure_performance(
    tokenizer: SARFTokenizer,
    samples: List[str],
    batch_sizes: Tuple[int, ...] = (1000, 10000),
    num_runs: int = 3,
) -> Dict:
    """Measure encode throughput and memory usage.

    Args:
        tokenizer: Object exposing encode() (and optionally encode_batch()).
        samples: Texts to benchmark; at most the first 10,000 are used for
            the single-thread and memory measurements.
        batch_sizes: Batch sizes to try when encode_batch() exists. (Tuple
            default replaces the old mutable-list default argument.)
        num_runs: Repetitions per benchmark; times are averaged.

    Returns:
        Dict with "single_thread", optional "batch_<n>", and "memory" entries.
    """
    results: Dict = {}

    # --- Single-threaded encode throughput --------------------------------
    bench_samples = samples[:10000]
    # Use the actual count: the old hard-coded 10000 overstated throughput
    # (and the reported sample count) whenever fewer samples were supplied.
    n_bench = len(bench_samples)
    print(" Single-threaded benchmark...", end=" ", flush=True)
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        for text in bench_samples:
            tokenizer.encode(text)
        times.append(time.perf_counter() - start)

    avg_time = sum(times) / len(times)
    throughput = n_bench / avg_time
    print(f"{throughput:,.0f} texts/sec")

    results["single_thread"] = {
        "throughput_per_sec": throughput,
        "avg_time": avg_time,
        "samples": n_bench,
    }

    # --- Batch encode throughput (only if the tokenizer supports it) ------
    if hasattr(tokenizer, 'encode_batch'):
        for batch_size in batch_sizes:
            batch_samples = samples[:batch_size]
            print(f" Batch encode ({batch_size:,})...", end=" ", flush=True)

            times = []
            for _ in range(num_runs):
                start = time.perf_counter()
                tokenizer.encode_batch(batch_samples)
                times.append(time.perf_counter() - start)

            avg_time = sum(times) / len(times)
            # Same fix as above: divide by what was actually encoded.
            throughput = len(batch_samples) / avg_time
            print(f"{throughput:,.0f} texts/sec")

            results[f"batch_{batch_size}"] = {
                "throughput_per_sec": throughput,
                "avg_time": avg_time,
                "samples": len(batch_samples),
            }

    # --- Peak memory while encoding ---------------------------------------
    print(" Memory measurement...", end=" ", flush=True)
    tracemalloc.start()

    for text in bench_samples:
        tokenizer.encode(text)

    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    print(f"Peak: {peak / 1024 / 1024:.1f} MB")

    results["memory"] = {
        "current_mb": current / 1024 / 1024,
        "peak_mb": peak / 1024 / 1024,
        "samples": n_bench,
    }

    return results
|
|
|
|
| |
| |
| |
|
|
def generate_report(
    roundtrip_results: Dict,
    edge_case_results: Dict,
    performance_results: Dict,
    tokenizer_name: str,
) -> str:
    """Generate a comprehensive plain-text report.

    All three result dicts may be empty (a phase was skipped); every section
    must degrade gracefully in that case.

    Args:
        roundtrip_results: Output of run_roundtrip_tests (or {}).
        edge_case_results: Output of run_edge_case_tests (or {}).
        performance_results: Output of measure_performance (or {}).
        tokenizer_name: Label for the tokenizer under test.

    Returns:
        The formatted report as a single string.
    """
    lines = []

    lines.append("=" * 80)
    lines.append(f"COMPREHENSIVE TEST REPORT - deeplatent-nlp v{version()}")
    lines.append("=" * 80)
    lines.append("")

    # Section 1: roundtrip accuracy table (rows only for present categories).
    lines.append("## 1. ROUNDTRIP ACCURACY")
    lines.append("-" * 70)
    lines.append(f"{'Category':<20} {'Samples':>12} {'Success':>12} {'Failed':>10} {'Accuracy':>12}")
    lines.append("-" * 70)

    for category in ["Arabic", "English", "Mixed", "TOTAL"]:
        if category in roundtrip_results:
            r = roundtrip_results[category]
            lines.append(
                f"{r['category']:<20} {r['total']:>12,} {r['success']:>12,} {r['failed']:>10,} {r['accuracy_pct']:>12}"
            )

    lines.append("-" * 70)
    lines.append("")

    # Section 2: edge case table.
    lines.append("## 2. EDGE CASE TESTS (12 categories)")
    lines.append("-" * 70)
    lines.append(f"{'Category':<30} {'Tests':>8} {'Passed':>8} {'Failed':>8}")
    lines.append("-" * 70)

    for category, r in edge_case_results.items():
        if category != "TOTAL":
            lines.append(f"{category:<30} {r['tests']:>8} {r['passed']:>8} {r['failed']:>8}")

    lines.append("-" * 70)
    # Guarded lookup: the old edge_case_results["TOTAL"] raised KeyError when
    # edge-case tests were skipped and main() passed an empty dict here.
    total = edge_case_results.get("TOTAL")
    if total:
        lines.append(f"{'TOTAL':<30} {total['tests']:>8} {total['passed']:>8} {total['failed']:>8}")
    lines.append("-" * 70)
    lines.append("")

    # Section 3: performance numbers (each line only if measured).
    lines.append("## 3. PERFORMANCE METRICS")
    lines.append("-" * 70)

    if "single_thread" in performance_results:
        st = performance_results["single_thread"]
        lines.append(f"Single-threaded: {st['throughput_per_sec']:,.0f} texts/sec")

    for key, value in performance_results.items():
        if key.startswith("batch_"):
            batch_size = key.replace("batch_", "")
            lines.append(f"Batch ({batch_size}): {value['throughput_per_sec']:,.0f} texts/sec")

    if "memory" in performance_results:
        mem = performance_results["memory"]
        lines.append(f"Memory (peak): {mem['peak_mb']:.1f} MB")

    lines.append("-" * 70)
    lines.append("")

    # Section 4: overall summary.
    lines.append("## 4. SUMMARY")
    lines.append("-" * 70)
    lines.append(f"Tokenizer: {tokenizer_name}")
    lines.append(f"Rust available: {RUST_AVAILABLE}")

    total_rt = roundtrip_results.get("TOTAL", {})
    if total_rt:
        lines.append(f"Roundtrip accuracy: {total_rt.get('accuracy_pct', 'N/A')}")

    total_ec = edge_case_results.get("TOTAL", {})
    if total_ec:
        lines.append(f"Edge case tests: {total_ec['passed']}/{total_ec['tests']} passed")

    lines.append("=" * 80)

    return "\n".join(lines)
|
|
|
|
| |
| |
| |
|
|
def main() -> int:
    """Run the comprehensive test suite from the command line.

    Phases: tokenizer loading (local path -> HF cache -> local cache -> HF
    Hub download), data loading, roundtrip tests, edge-case tests,
    performance tests, and report generation.

    Returns:
        0 when all enabled phases meet their thresholds, 1 otherwise
        (used as the process exit code by the __main__ guard).
    """
    parser = argparse.ArgumentParser(description="Million-scale comprehensive tests")
    parser.add_argument(
        "--samples",
        type=int,
        default=100000,
        help="Number of samples to test (default: 100000)",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default=DATA_DIR,
        help="Path to base_data directory",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default=HF_REPO,
        help="Tokenizer name or path",
    )
    parser.add_argument(
        "--report",
        action="store_true",
        help="Generate JSON report",
    )
    parser.add_argument(
        "--skip-roundtrip",
        action="store_true",
        help="Skip roundtrip tests",
    )
    parser.add_argument(
        "--skip-edge-cases",
        action="store_true",
        help="Skip edge case tests",
    )
    parser.add_argument(
        "--skip-performance",
        action="store_true",
        help="Skip performance tests",
    )
    args = parser.parse_args()

    print("=" * 80)
    print("COMPREHENSIVE TEST SUITE - deeplatent-nlp")
    print("=" * 80)
    print(f"Version: {version()}")
    print(f"Rust available: {RUST_AVAILABLE}")
    print(f"Samples: {args.samples:,}")
    print()

    # --- Tokenizer loading: try sources in priority order until one works. ---
    print("Loading tokenizer...")
    tokenizer = None
    tokenizer_source = args.tokenizer

    # 1) Explicit local path given via --tokenizer.
    if os.path.exists(args.tokenizer):
        try:
            tokenizer = SARFTokenizer.from_pretrained(args.tokenizer)
            print(f" Loaded from local path: {args.tokenizer}")
        except Exception as e:
            print(f" Local load failed: {e}")

    # 2) Previously downloaded HuggingFace cache.
    # NOTE(review): HF_TOKENIZER_PATH may not match the directory that
    # download_tokenizer_from_hf() writes ("<cache>/almaghrabima_SARFTokenizer"),
    # so this fallback may never fire — verify.
    if tokenizer is None and os.path.exists(HF_TOKENIZER_PATH):
        try:
            tokenizer = SARFTokenizer.from_pretrained(HF_TOKENIZER_PATH)
            tokenizer_source = HF_REPO
            print(f" Loaded from HuggingFace cache: {HF_TOKENIZER_PATH}")
        except Exception as e:
            print(f" HF cache load failed: {e}")

    # 3) Known local tokenizer build.
    if tokenizer is None and os.path.exists(LOCAL_TOKENIZER):
        try:
            tokenizer = SARFTokenizer.from_pretrained(LOCAL_TOKENIZER)
            tokenizer_source = LOCAL_TOKENIZER
            print(f" Loaded from local cache: {LOCAL_TOKENIZER}")
        except Exception as e:
            print(f" Local cache load failed: {e}")

    # 4) Fresh download from the Hub (only for "org/name"-style ids).
    if tokenizer is None and "/" in args.tokenizer:
        try:
            print(f" Downloading from HuggingFace: {args.tokenizer}")
            local_path = download_tokenizer_from_hf(args.tokenizer)
            tokenizer = SARFTokenizer.from_pretrained(local_path)
            tokenizer_source = args.tokenizer
            print(f" Loaded from HuggingFace Hub")
        except Exception as e:
            print(f" HuggingFace download failed: {e}")

    if tokenizer is None:
        print(" Failed to load tokenizer from any source!")
        sys.exit(1)

    print(f" Vocab size: {tokenizer.vocab_size:,}")

    # Accumulated results; serialized to JSON when --report is set.
    results = {
        "version": version(),
        "rust_available": RUST_AVAILABLE,
        "tokenizer": tokenizer_source,
        "samples": args.samples,
    }

    # --- Data loading (falls back to tiny synthetic sets if shards missing). ---
    print("\nLoading test data...")
    try:
        arabic_samples, english_samples, mixed_samples = load_base_data(args.data_dir, args.samples)
    except FileNotFoundError as e:
        print(f" Warning: {e}")
        print(" Using synthetic test data...")
        arabic_samples = ["مرحبا بالعالم"] * 1000
        english_samples = ["Hello world"] * 1000
        mixed_samples = ["Hello مرحبا world"] * 1000

    # --- Phase 1: roundtrip accuracy ---
    roundtrip_results = {}
    if not args.skip_roundtrip:
        print("\n" + "=" * 60)
        print("1. ROUNDTRIP TESTS")
        print("=" * 60)
        roundtrip_results = run_roundtrip_tests(
            tokenizer, arabic_samples, english_samples, mixed_samples
        )
        results["roundtrip"] = roundtrip_results

    # --- Phase 2: edge cases ---
    edge_case_results = {}
    if not args.skip_edge_cases:
        print("\n" + "=" * 60)
        print("2. EDGE CASE TESTS")
        print("=" * 60)
        edge_case_results = run_edge_case_tests()
        results["edge_cases"] = edge_case_results

        # Per-category pass/fail summary on stdout.
        for category, r in edge_case_results.items():
            if category != "TOTAL":
                status = "PASS" if r["failed"] == 0 else f"FAIL ({r['failed']})"
                print(f" {category}: {status}")

        total = edge_case_results["TOTAL"]
        print(f"\n TOTAL: {total['passed']}/{total['tests']} passed")

    # --- Phase 3: performance ---
    performance_results = {}
    if not args.skip_performance:
        print("\n" + "=" * 60)
        print("3. PERFORMANCE TESTS")
        print("=" * 60)
        all_samples = arabic_samples + english_samples + mixed_samples
        performance_results = measure_performance(tokenizer, all_samples)
        results["performance"] = performance_results

    # --- Final report (skipped phases contribute empty dicts). ---
    print("\n" + "=" * 60)
    print("REPORT")
    print("=" * 60)

    report = generate_report(
        roundtrip_results,
        edge_case_results,
        performance_results,
        tokenizer_source,
    )
    print(report)

    # Optional JSON dump of all collected results.
    if args.report:
        output_path = "test_comprehensive_results.json"
        with open(output_path, "w", encoding="utf-8") as f:
            # Roundtrip through default=str first so any non-serializable
            # values are stringified before the pretty dump.
            clean_results = json.loads(json.dumps(results, default=str))
            json.dump(clean_results, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {output_path}")

    # Exit code: non-zero when roundtrip accuracy or edge cases fall short.
    total_rt = roundtrip_results.get("TOTAL", {})
    total_ec = edge_case_results.get("TOTAL", {})

    if total_rt and total_rt.get("accuracy", 1.0) < 0.99:
        print("\nWARNING: Roundtrip accuracy below 99%")
        return 1

    if total_ec and total_ec.get("failed", 0) > 0:
        print(f"\nWARNING: {total_ec['failed']} edge case tests failed")
        return 1

    print("\nAll tests passed!")
    return 0
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|