File size: 4,630 Bytes
4dd7afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import json
from safetensors.torch import load_file
from id_mapper import UnifiedIdMapper
from modeling_llama_edge import LlamaEdgeForCausalLM
from configuration_llama_edge import LlamaEdgeConfig

class ModelWrapper:
    """Inference wrapper pairing a LlamaEdge causal LM with a UnifiedIdMapper.

    Loads the ID mapper and model weights, resolves the target device with
    graceful CPU fallback when the requested accelerator is unavailable, and
    exposes `predict` to rank next-token candidates (reported as old IDs).
    """

    def __init__(self, mapper_path, model_path, device="cuda"):
        """
        Args:
            mapper_path: Path to the UnifiedIdMapper serialized file.
            model_path: Path to the model weights (.safetensors).
            device: Preferred device string ("cuda", "mps", "cpu", ...);
                falls back to CPU when the accelerator is not available.
        """
        # Load Mapper
        print(f"Loading mapper from {mapper_path}...")
        self.mapper = UnifiedIdMapper.from_file(mapper_path)

        # Initialize Empty Model
        print("Initializing model...")
        config = LlamaEdgeConfig() # Use defaults or load from file if exists
        self.model = LlamaEdgeForCausalLM(config)

        # Load Weights onto CPU first; moved to the target device below.
        print(f"Loading weights from {model_path}...")
        state_dict = load_file(model_path, device="cpu")
        self.model.load_state_dict(state_dict)

        # Resolve device with availability checks so a missing accelerator
        # degrades to CPU instead of crashing at model.to(...). The original
        # code checked CUDA but not MPS; both are now guarded symmetrically.
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            self.device = torch.device("cpu")
        elif device == "mps" and not torch.backends.mps.is_available():
            print("MPS not available, switching to CPU.")
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        print(f"Moving model to {self.device}...")
        self.model.to(self.device)
        self.model.eval()

    def predict(self, old_ids_context):
        """
        Args:
            old_ids_context: List of old IDs defining the context.
        Returns:
            sorted_predictions: List of (prob, old_id, label) sorted by probability descending.
        Raises:
            KeyError: If an input old ID is missing from the mapper
                (inputs are assumed to be mappable; only output-side
                indices are tolerated as unknown).
        """
        # 1. Convert context list of old IDs to new IDs
        input_ids = []
        for old_id in old_ids_context:
            # We assume the input old_ids exist in the mapper
            new_id, _ = self.mapper.map_old_id(old_id)
            input_ids.append(new_id)

        # 2. Run inference (batch size = 1), tensor created on the model's device.
        model_input = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            # NOTE(review): assumes the model returns raw logits of shape
            # (batch, seq, vocab) — confirm against LlamaEdgeForCausalLM.
            logits = self.model(model_input)
            # Get logits for the last token in the sequence
            last_token_logits = logits[0, -1, :]
            probs = torch.softmax(last_token_logits, dim=-1)

        # 3. Sort by probability descending
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        sorted_probs = sorted_probs.tolist()
        sorted_indices = sorted_indices.tolist() # These indices are the new_ids

        # 4. Create result list with mapping applied
        results = []
        for prob, new_id in zip(sorted_probs, sorted_indices):
            try:
                # map_new_id returns (old_id, is_edge)
                old_id, _ = self.mapper.map_new_id(new_id)
                label = self.mapper.label_from_new_id(new_id)
                results.append((prob, old_id, label))
            except KeyError:
                # Handle indices not in mapper (e.g., padding tokens)
                results.append((prob, -1, "<PAD/UNK>"))

        return results

def main():
    """Smoke test: load the wrapper, predict on one training example, report top 10."""
    mapper_path = "unified_id_mapper.json"
    model_path = "model.safetensors"

    # Prefer CUDA, then Apple MPS, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # Build the wrapper (loads mapper + weights and moves the model).
    wrapper = ModelWrapper(mapper_path, model_path, device=device)

    # Input instance take from first line of the training data
    input_ids = [108, 112, 117, 234, 421, 582, 601, 608, 940, 941, 948, 1008, 1009, 1076, 1094, 1095, 1125, 1188, 1251, 1275, 1365, 1415, 1522, 1687, 1948, 1977, 2025, 47178924, 47185647]
    target_edge_id = 47182521
    target_edge_label = "/people/person/place_of_birth"

    predictions = wrapper.predict(input_ids)
    print(f"Input old IDs: {input_ids}")
    print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")

    # Show the ten most probable edges with their mapped old IDs and labels.
    top_ten = predictions[:10]
    print("Top 10 Predictions:")
    for rank, (prob, pred_old_id, pred_label) in enumerate(top_ten, start=1):
        print(f"  Rank {rank}: Old ID {pred_old_id}, Label: {pred_label}, Probability: {prob:.6f}")

    # Report whether the expected target edge made the top-10 cut.
    if any(old_id == target_edge_id for _, old_id, _ in top_ten):
        print(f"Target edge old ID {target_edge_id} found in top 10 predictions.")
    else:
        print(f"Target edge old ID {target_edge_id} NOT found in top 10 predictions.")
    print("-" * 50)

# Run the smoke test only when executed directly, not on import.
if __name__ == "__main__":
    main()