"""Next-token prediction with a LlamaEdge model over a unified old/new ID mapping."""

import json

import torch
from safetensors.torch import load_file

from configuration_llama_edge import LlamaEdgeConfig
from id_mapper import UnifiedIdMapper
from modeling_llama_edge import LlamaEdgeForCausalLM


class ModelWrapper:
    """Bundles a LlamaEdgeForCausalLM model with a UnifiedIdMapper.

    The mapper translates between the dataset's original ("old") IDs and the
    contiguous vocabulary ("new") IDs the model operates on.
    """

    def __init__(self, mapper_path, model_path, device="cuda"):
        """Load the mapper and weights, then move the model to a device.

        Args:
            mapper_path: Path to the UnifiedIdMapper JSON file.
            model_path: Path to the safetensors weight file.
            device: Preferred device ("cuda", "mps", or "cpu"). Falls back
                to CPU when the requested accelerator is unavailable.
        """
        # Load Mapper
        print(f"Loading mapper from {mapper_path}...")
        self.mapper = UnifiedIdMapper.from_file(mapper_path)

        # Initialize Empty Model
        print("Initializing model...")
        config = LlamaEdgeConfig()  # Use defaults or load from file if exists
        self.model = LlamaEdgeForCausalLM(config)

        # Load Weights (to CPU first; moved to the target device below)
        print(f"Loading weights from {model_path}...")
        state_dict = load_file(model_path, device="cpu")
        self.model.load_state_dict(state_dict)

        # Resolve the device, degrading to CPU when the requested
        # accelerator is not available. (Fix: the original guarded only
        # CUDA; an explicit "mps" request on a machine without MPS would
        # crash at model.to().)
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            self.device = torch.device("cpu")
        elif device == "mps":
            if torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                print("MPS not available, switching to CPU.")
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        print(f"Moving model to {self.device}...")
        self.model.to(self.device)
        self.model.eval()

    def predict(self, old_ids_context):
        """Predict the next token for a context given as old IDs.

        Args:
            old_ids_context: List of old IDs defining the context.

        Returns:
            List of (prob, old_id, label) tuples sorted by probability
            descending. Entries whose predicted new ID is unknown to the
            mapper (e.g., padding tokens) are reported as (prob, -1, "").
        """
        # 1. Convert context list of old IDs to new IDs.
        #    We assume the input old_ids exist in the mapper
        #    (map_old_id returns (new_id, is_edge)).
        input_ids = [self.mapper.map_old_id(old_id)[0] for old_id in old_ids_context]

        # 2. Run inference (batch size = 1) on the model's device.
        model_input = torch.tensor([input_ids], dtype=torch.long, device=self.device)
        with torch.no_grad():
            # NOTE(review): assumes the model returns raw logits as a tensor
            # of shape (batch, seq, vocab) — confirm against the model class.
            logits = self.model(model_input)

        # Probability distribution over the vocabulary for the last position.
        last_token_logits = logits[0, -1, :]
        probs = torch.softmax(last_token_logits, dim=-1)

        # 3. Sort by probability descending; the sorted indices ARE new IDs.
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        sorted_probs = sorted_probs.tolist()
        sorted_indices = sorted_indices.tolist()

        # 4. Map each predicted new ID back to its old ID and label.
        results = []
        for prob, new_id in zip(sorted_probs, sorted_indices):
            try:
                # map_new_id returns (old_id, is_edge)
                old_id, _ = self.mapper.map_new_id(new_id)
                label = self.mapper.label_from_new_id(new_id)
                results.append((prob, old_id, label))
            except KeyError:
                # Handle indices not in mapper (e.g., padding tokens)
                results.append((prob, -1, ""))
        return results


def main():
    """Load the model, run one example prediction, and report the top-10 hits."""
    # Define paths
    mapper_path = "unified_id_mapper.json"
    model_path = "model.safetensors"

    # Pick the best available device.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Instantiate the wrapper
    wrapper = ModelWrapper(mapper_path, model_path, device=device)

    # Input instance taken from the first line of the training data.
    input_ids = [108, 112, 117, 234, 421, 582, 601, 608, 940, 941, 948,
                 1008, 1009, 1076, 1094, 1095, 1125, 1188, 1251, 1275,
                 1365, 1415, 1522, 1687, 1948, 1977, 2025,
                 47178924, 47185647]
    target_edge_id = 47182521
    target_edge_label = "/people/person/place_of_birth"

    predictions = wrapper.predict(input_ids)

    print(f"Input old IDs: {input_ids}")
    print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")
    print("Top 10 Predictions:")
    for rank, (prob, pred_old_id, pred_label) in enumerate(predictions[:10], start=1):
        print(f" Rank {rank}: Old ID {pred_old_id}, Label: {pred_label}, Probability: {prob:.6f}")

    # Check if the target is in top 10
    top_10_old_ids = [pred_old_id for _, pred_old_id, _ in predictions[:10]]
    if target_edge_id in top_10_old_ids:
        print(f"Target edge old ID {target_edge_id} found in top 10 predictions.")
    else:
        print(f"Target edge old ID {target_edge_id} NOT found in top 10 predictions.")
    print("-" * 50)


if __name__ == "__main__":
    main()