| """
|
| COGNITIVE-CORE: Checkpoint Loading & Key Remapping
|
| ===================================================
|
|
|
| This module provides robust checkpoint loading with automatic key remapping
|
| to handle different checkpoint formats (with/without 'model.' prefix, etc.)
|
|
|
| Copyright © 2026 Mike Amega (Logo) - Ame Web Studio
|
| License: Proprietary - All Rights Reserved
|
| """
|
|
|
import re
from typing import Any, Dict, Optional, Set

import torch
|
|
|
|
|
def remap_checkpoint_keys(
    checkpoint_state_dict: Dict[str, torch.Tensor],
    model_state_dict: Dict[str, torch.Tensor],
    verbose: bool = False,
) -> Dict[str, torch.Tensor]:
    """Automatically remap checkpoint keys to match the target model.

    Handles the following scenarios:
    1. Checkpoint has a 'model.' prefix but the model does not -> strip prefix
    2. Checkpoint lacks the 'model.' prefix but the model has one -> add prefix
    3. Other common prefixes ('backbone.', 'encoder.')

    Args:
        checkpoint_state_dict: State dict loaded from the checkpoint.
        model_state_dict: State dict of the target model.
        verbose: Print remapping details.

    Returns:
        A remapped state dict compatible with the model. Tensor values are
        shared with the input (not copied); only key names change.
    """
    model_keys = set(model_state_dict.keys())
    checkpoint_keys = set(checkpoint_state_dict.keys())

    # Fast path: if at least 90% of the checkpoint keys already match the
    # model, use the checkpoint unchanged.
    matching = model_keys & checkpoint_keys
    if len(matching) >= len(checkpoint_keys) * 0.9:
        if verbose:
            print(
                f"✅ Checkpoint compatible: {len(matching)}/{len(checkpoint_keys)} clés correspondent"
            )
        return checkpoint_state_dict

    # Prefix-based strategies, each applied independently to the original
    # checkpoint; the one producing the most matching keys wins.
    strategies = [
        ("remove_model_prefix", _remove_prefix, "model."),
        ("add_model_prefix", _add_prefix, "model."),
        ("remove_backbone_prefix", _remove_prefix, "backbone."),
        ("remove_encoder_prefix", _remove_prefix, "encoder."),
    ]

    best_strategy = None
    best_match_count = len(matching)
    best_result = checkpoint_state_dict

    for name, func, prefix in strategies:
        remapped = func(checkpoint_state_dict, prefix)
        match_count = len(model_keys & set(remapped.keys()))

        if match_count > best_match_count:
            best_match_count = match_count
            best_strategy = name
            best_result = remapped

    if verbose and best_strategy:
        print(f"🔄 Stratégie appliquée: {best_strategy}")
        print(f"   Clés correspondantes: {best_match_count}/{len(checkpoint_keys)}")

    # Last resort: per-key heuristic mapping when fewer than half the keys
    # matched. BUG FIX: previously the smart-mapping result replaced
    # best_result unconditionally, even when it matched FEWER keys than the
    # best prefix strategy; now it is adopted only when it improves the
    # match count.
    if best_match_count < len(checkpoint_keys) * 0.5:
        smart_result = _smart_key_mapping(checkpoint_state_dict, model_keys)
        smart_match_count = len(model_keys & set(smart_result.keys()))
        if smart_match_count > best_match_count:
            best_result = smart_result
            best_match_count = smart_match_count
        if verbose:
            final_match = len(model_keys & set(best_result.keys()))
            print(
                f"🧠 Remappage intelligent: {final_match}/{len(checkpoint_keys)} clés"
            )

    return best_result
|
|
|
|
|
| def _remove_prefix(state_dict: Dict, prefix: str) -> Dict:
|
| """Retirer un préfixe de toutes les clés."""
|
| return {
|
| (k[len(prefix) :] if k.startswith(prefix) else k): v
|
| for k, v in state_dict.items()
|
| }
|
|
|
|
|
| def _add_prefix(state_dict: Dict, prefix: str) -> Dict:
|
| """Ajouter un préfixe à toutes les clés."""
|
| return {f"{prefix}{k}": v for k, v in state_dict.items()}
|
|
|
|
|
| def _smart_key_mapping(
|
| checkpoint_dict: Dict[str, torch.Tensor], model_keys: Set[str]
|
| ) -> Dict[str, torch.Tensor]:
|
| """
|
| Mapping intelligent clé par clé basé sur les suffixes et patterns.
|
| """
|
| result = {}
|
| model_keys_list = list(model_keys)
|
|
|
| for ckpt_key, value in checkpoint_dict.items():
|
|
|
| if ckpt_key in model_keys:
|
| result[ckpt_key] = value
|
| continue
|
|
|
|
|
| with_prefix = f"model.{ckpt_key}"
|
| if with_prefix in model_keys:
|
| result[with_prefix] = value
|
| continue
|
|
|
|
|
| if ckpt_key.startswith("model."):
|
| without_prefix = ckpt_key[6:]
|
| if without_prefix in model_keys:
|
| result[without_prefix] = value
|
| continue
|
|
|
|
|
| ckpt_suffix = ckpt_key.split(".")[-1]
|
| ckpt_base = ".".join(ckpt_key.split(".")[:-1])
|
|
|
| for model_key in model_keys_list:
|
| if model_key.endswith(ckpt_suffix):
|
| model_base = ".".join(model_key.split(".")[:-1])
|
|
|
| if _keys_similar(ckpt_base, model_base):
|
| result[model_key] = value
|
| break
|
| else:
|
|
|
| result[ckpt_key] = value
|
|
|
| return result
|
|
|
|
|
| def _keys_similar(key1: str, key2: str) -> bool:
|
| """Vérifier si deux clés sont structurellement similaires."""
|
| parts1 = key1.split(".")
|
| parts2 = key2.split(".")
|
|
|
|
|
| if len(parts1) != len(parts2):
|
| return False
|
|
|
|
|
| matches = sum(
|
| 1 for p1, p2 in zip(parts1, parts2) if p1 == p2 or p1.isdigit() and p2.isdigit()
|
| )
|
| return matches >= len(parts1) * 0.7
|
|
|
|
|
def validate_checkpoint(
    checkpoint_state_dict: Dict[str, torch.Tensor],
    model_state_dict: Dict[str, torch.Tensor],
    strict: bool = False,
) -> Dict[str, Any]:
    """Validate that a checkpoint is compatible with a model.

    BUG FIX: the return annotation previously used the builtin ``any``
    (a function) instead of ``typing.Any``.

    Args:
        checkpoint_state_dict: State dict loaded from the checkpoint.
        model_state_dict: State dict of the target model.
        strict: When True, any missing key invalidates the checkpoint;
            otherwise up to 10% of the model's keys may be missing.

    Returns:
        Dict with:
        - valid: bool
        - missing_keys: keys absent from the checkpoint
        - unexpected_keys: checkpoint keys unknown to the model
        - size_mismatches: keys whose tensor shapes differ
        - matched_keys: number of keys present in both
        - total_model_keys: number of keys in the model
    """
    model_keys = set(model_state_dict.keys())
    ckpt_keys = set(checkpoint_state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    # A name match is not enough: shape mismatches make a key unloadable.
    size_mismatches = []
    for key in model_keys & ckpt_keys:
        model_shape = model_state_dict[key].shape
        ckpt_shape = checkpoint_state_dict[key].shape
        if model_shape != ckpt_shape:
            size_mismatches.append(
                {"key": key, "model_shape": model_shape, "checkpoint_shape": ckpt_shape}
            )

    valid = len(missing) == 0 and len(size_mismatches) == 0
    if not strict:
        # Lenient mode: tolerate fewer than 10% missing keys (e.g. a newly
        # added head), but never a shape mismatch.
        valid = len(size_mismatches) == 0 and len(missing) < len(model_keys) * 0.1

    return {
        "valid": valid,
        "missing_keys": list(missing),
        "unexpected_keys": list(unexpected),
        "size_mismatches": size_mismatches,
        "matched_keys": len(model_keys & ckpt_keys),
        "total_model_keys": len(model_keys),
    }
|
|
|
|
|
def save_cognitive_checkpoint(
    model,
    path: str,
    include_optimizer: bool = False,
    optimizer=None,
    extra_state: Optional[Dict] = None,
):
    """Save a cognitive-model checkpoint to disk.

    Args:
        model: The model to save.
        path: Destination file path.
        include_optimizer: Whether to include the optimizer state.
        optimizer: The optimizer (used when include_optimizer=True).
        extra_state: Additional state to store alongside the model.
    """
    config_dict = model.config.to_dict() if hasattr(model, "config") else {}
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "config": config_dict,
    }

    if include_optimizer and optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()

    # Persist model-specific cognitive state when the model exposes it.
    if hasattr(model, "get_cognitive_state"):
        checkpoint["cognitive_state"] = model.get_cognitive_state()

    if extra_state:
        checkpoint["extra_state"] = extra_state

    torch.save(checkpoint, path)
    print(f"✅ Checkpoint sauvegardé: {path}")
|
|
|
|
|
def load_cognitive_checkpoint(
    model, path: str, strict: bool = False, verbose: bool = True
) -> Dict:
    """Load a checkpoint into a cognitive model with automatic key remapping.

    Args:
        model: The target model.
        path: Checkpoint file path.
        strict: Strict validation mode (missing keys invalidate the report).
        verbose: Print loading details.

    Returns:
        Dict with the validation report, the stored config, and any extra
        state saved with the checkpoint.
    """
    # NOTE(review): torch.load is called without weights_only; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu")

    # The weights may be nested under different container keys depending on
    # how the checkpoint was produced.
    for container_key in ("model_state_dict", "state_dict"):
        if container_key in checkpoint:
            weights = checkpoint[container_key]
            break
    else:
        weights = checkpoint

    remapped_weights = remap_checkpoint_keys(
        weights, model.state_dict(), verbose=verbose
    )

    report = validate_checkpoint(remapped_weights, model.state_dict(), strict=strict)

    if verbose:
        print(
            f"📊 Clés chargées: {report['matched_keys']}/{report['total_model_keys']}"
        )
        if report["missing_keys"]:
            print(f"⚠️ Clés manquantes: {len(report['missing_keys'])}")
        if report["size_mismatches"]:
            print(f"⚠️ Tailles incompatibles: {len(report['size_mismatches'])}")

    model.load_state_dict(remapped_weights, strict=False)

    # NOTE(review): cognitive-state restoration is a placeholder — any saved
    # 'cognitive_state' is currently ignored on load; confirm this is intended.
    if "cognitive_state" in checkpoint and hasattr(model, "reset_cognitive_state"):
        pass

    if verbose:
        print("✅ Checkpoint chargé avec succès")

    return {
        "validation": report,
        "config": checkpoint.get("config", {}),
        "extra_state": checkpoint.get("extra_state", {}),
    }
|
|
|