| import torch |
| import json |
| from safetensors.torch import load_file |
| from id_mapper import UnifiedIdMapper |
| from modeling_llama_edge import LlamaEdgeForCausalLM |
| from configuration_llama_edge import LlamaEdgeConfig |
|
|
class ModelWrapper:
    """Inference wrapper around a LlamaEdge causal LM plus a unified ID mapper.

    The mapper translates external ("old") IDs into the model's contiguous
    vocabulary ("new" IDs) and back, so callers work entirely in old IDs.
    """

    def __init__(self, mapper_path, model_path, device="cuda"):
        """Load the ID mapper, build the model, load weights, and pick a device.

        Args:
            mapper_path: Path to the serialized UnifiedIdMapper file.
            model_path: Path to the model weights in safetensors format.
            device: Requested device string ("cuda", "mps", or "cpu").
                Falls back to CPU if the requested accelerator is unavailable.
        """
        print(f"Loading mapper from {mapper_path}...")
        self.mapper = UnifiedIdMapper.from_file(mapper_path)

        print("Initializing model...")
        config = LlamaEdgeConfig()
        self.model = LlamaEdgeForCausalLM(config)

        print(f"Loading weights from {model_path}...")
        # Load to CPU first; the model is moved to the target device below.
        state_dict = load_file(model_path, device="cpu")
        self.model.load_state_dict(state_dict)

        # Fall back to CPU when the requested accelerator is unavailable,
        # instead of crashing later in .to(device). The original code only
        # guarded the CUDA path; MPS gets the same treatment now.
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            self.device = torch.device("cpu")
        elif device == "mps" and not torch.backends.mps.is_available():
            print("MPS not available, switching to CPU.")
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        print(f"Moving model to {self.device}...")
        self.model.to(self.device)
        self.model.eval()

    def predict(self, old_ids_context, top_k=None):
        """Predict next-token probabilities for a context of old IDs.

        Args:
            old_ids_context: List of old IDs defining the context.
            top_k: Optional int. If given, only the top_k most probable
                entries are returned (clamped to the vocabulary size);
                if None, the full sorted distribution is returned
                (original behavior).
        Returns:
            sorted_predictions: List of (prob, old_id, label) sorted by
            probability descending. Entries whose new ID has no mapping
            are reported as (prob, -1, "<PAD/UNK>").
        """
        # Map the external ("old") IDs into the model's contiguous vocabulary.
        input_ids = [self.mapper.map_old_id(old_id)[0] for old_id in old_ids_context]

        model_input = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            logits = self.model(model_input)

        # Next-token distribution, taken from the last position only.
        probs = torch.softmax(logits[0, -1, :], dim=-1)

        if top_k is None:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        else:
            # topk avoids sorting the whole vocabulary when only a prefix is needed.
            sorted_probs, sorted_indices = torch.topk(probs, k=min(top_k, probs.numel()))

        results = []
        for prob, new_id in zip(sorted_probs.tolist(), sorted_indices.tolist()):
            try:
                old_id, _ = self.mapper.map_new_id(new_id)
                label = self.mapper.label_from_new_id(new_id)
                results.append((prob, old_id, label))
            except KeyError:
                # IDs the mapper does not know (padding / unknown slots) are
                # kept as placeholders so ranks stay aligned with probabilities.
                results.append((prob, -1, "<PAD/UNK>"))

        return results
|
|
def main():
    """Demo entry point: load the wrapper and score one example context."""
    mapper_path = "unified_id_mapper.json"
    model_path = "model.safetensors"

    # Prefer CUDA, then Apple MPS, then plain CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    wrapper = ModelWrapper(mapper_path, model_path, device=device)

    # Example context (old IDs) and the held-out target edge we hope to rank.
    input_ids = [108, 112, 117, 234, 421, 582, 601, 608, 940, 941, 948, 1008, 1009, 1076, 1094, 1095, 1125, 1188, 1251, 1275, 1365, 1415, 1522, 1687, 1948, 1977, 2025, 47178924, 47185647]
    target_edge_id = 47182521
    target_edge_label = "/people/person/place_of_birth"

    predictions = wrapper.predict(input_ids)
    print(f"Input old IDs: {input_ids}")
    print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")

    top_10 = predictions[:10]
    print("Top 10 Predictions:")
    for rank, (prob, pred_old_id, pred_label) in enumerate(top_10, start=1):
        print(f" Rank {rank}: Old ID {pred_old_id}, Label: {pred_label}, Probability: {prob:.6f}")

    # Report whether the target edge made it into the top 10.
    hit = any(pred_old_id == target_edge_id for _, pred_old_id, _ in top_10)
    if hit:
        print(f"Target edge old ID {target_edge_id} found in top 10 predictions.")
    else:
        print(f"Target edge old ID {target_edge_id} NOT found in top 10 predictions.")
    print("-" * 50)
|
|
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|