"""Patch missing ``eos_token_id`` entries in models' generation_config.json.

Reads a mergekit YAML config, takes ``eos_token_id`` (and ``bos_token_id``)
from the base model's generation_config.json, and writes them into any listed
model whose generation_config.json lacks an ``eos_token_id``. Models that
already carry a *different* id are reported but left untouched.
"""
import yaml
import json
import os
import shutil
import argparse
from colorama import init, Fore, Style

init()

# Used only when the base model's generation_config.json has no bos_token_id.
# Kept equal to the original script's hard-coded value ("Standard Mistral
# assumption").
DEFAULT_BOS_TOKEN_ID = 1


def load_json(path):
    """Load a JSON object from *path*.

    Returns an empty dict when the file does not exist, cannot be read or
    parsed, or does not contain a JSON object. Read/parse problems are
    reported as a warning rather than silently ignored.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"{Fore.YELLOW}Warning: could not read {path}: {e}{Style.RESET_ALL}")
        return {}
    if not isinstance(data, dict):
        # Callers use dict methods on the result; a list/scalar would crash them.
        print(f"{Fore.YELLOW}Warning: {path} does not contain a JSON object; "
              f"treating it as empty.{Style.RESET_ALL}")
        return {}
    return data


def save_json(path, data):
    """Write *data* to *path* as indented JSON.

    The content goes to a temporary file first and is then moved over *path*.
    An interrupted write therefore never leaves a truncated config behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)


def _display_name(path):
    """Return the last path component, tolerating trailing separators."""
    return os.path.basename(os.path.normpath(path))


def _model_paths(config):
    """Return the ``model`` value of every dict entry under ``config['models']``.

    Entries that are not dicts are ignored, as in the original behavior.
    Dict entries without a usable string ``model`` value are reported and
    skipped instead of raising KeyError.
    """
    paths = []
    # ``or []`` also handles an explicit ``models: null`` in the YAML.
    for entry in config.get('models') or []:
        if not isinstance(entry, dict):
            continue
        model_path = entry.get('model')
        if not isinstance(model_path, str) or not model_path:
            print(f"{Fore.YELLOW}Skipping models entry without a valid 'model' path: "
                  f"{entry}{Style.RESET_ALL}")
            continue
        paths.append(model_path)
    return paths


def main():
    """CLI entry point: patch missing eos_token_id values for a mergekit config."""
    parser = argparse.ArgumentParser(description="Patch missing EOS IDs in generation_config.json")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()

    print(f"{Fore.CYAN}--- GENERATION CONFIG PATCHER ---{Style.RESET_ALL}")

    # 1. Load Config
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        # An empty YAML file yields None; anything else would break .get().
        print(f"{Fore.RED}Config {args.config} is empty or not a mapping.{Style.RESET_ALL}")
        return

    base_model_path = config.get('base_model')
    if not base_model_path:
        print("No base_model found.")
        return

    # 2. Get Target EOS ID from Base Model
    print(f"Reading Base Model: {_display_name(base_model_path)}")
    base_gen_path = os.path.join(base_model_path, "generation_config.json")
    if not os.path.isfile(base_gen_path):
        print(f"{Fore.RED}CRITICAL: {base_gen_path} not found. Cannot patch.{Style.RESET_ALL}")
        return
    base_gen = load_json(base_gen_path)

    target_eos_id = base_gen.get("eos_token_id")
    if target_eos_id is None:
        print(f"{Fore.RED}CRITICAL: Base model lacks eos_token_id. Cannot patch.{Style.RESET_ALL}")
        return

    # Prefer the base model's own bos id over the hard-coded fallback.
    target_bos_id = base_gen.get("bos_token_id")
    if target_bos_id is None:
        target_bos_id = DEFAULT_BOS_TOKEN_ID

    print(f"Target EOS ID is: {Fore.GREEN}{target_eos_id}{Style.RESET_ALL}")
    print("-" * 60)

    # 3. Iterate and Patch
    patched_count = 0
    for model_path in _model_paths(config):
        model_name = _display_name(model_path).replace("!models--", "")

        # Hub IDs or missing directories cannot be patched locally; writing
        # into them would raise and abort the remaining models.
        if not os.path.isdir(model_path):
            print(f"Skipping {model_name}: {model_path} is not a local directory")
            continue

        gen_path = os.path.join(model_path, "generation_config.json")
        data = load_json(gen_path)
        current_id = data.get("eos_token_id")

        # Only patch if MISSING. A differing existing id (e.g. 999) is a real
        # mismatch and is reported, not overwritten.
        if current_id is None:
            print(f"Patching {model_name}...")

            try:
                # Back up the existing file before modifying it.
                if os.path.exists(gen_path):
                    shutil.copy(gen_path, gen_path + ".bak")

                data["eos_token_id"] = target_eos_id
                # Ensure a bos id exists as well.
                if "bos_token_id" not in data:
                    data["bos_token_id"] = target_bos_id

                save_json(gen_path, data)
            except OSError as e:
                print(f"  {Fore.RED}-> Failed to patch {model_name}: {e}{Style.RESET_ALL}")
                continue

            print(f"  {Fore.GREEN}-> Fixed: Added eos_token_id: {target_eos_id}{Style.RESET_ALL}")
            patched_count += 1

        elif str(current_id) != str(target_eos_id):
            print(f"Skipping {model_name}: Has ID {current_id} (Mismatch, not missing)")
        # Otherwise the id already matches; nothing to do.

    print("-" * 60)
    print(f"Operation Complete. Patched {patched_count} models.")
    print("Run eos_scanner.py again to verify results.")


if __name__ == "__main__":
    main()