# llama3-edge / inference.py
# Uploaded by crab27 using huggingface_hub (commit 4dd7afe, verified)
import torch
import json
from safetensors.torch import load_file
from id_mapper import UnifiedIdMapper
from modeling_llama_edge import LlamaEdgeForCausalLM
from configuration_llama_edge import LlamaEdgeConfig
class ModelWrapper:
    """Bundle the unified id mapper, the LlamaEdge model, and device placement
    behind a single ``predict`` call."""

    def __init__(self, mapper_path, model_path, device="cuda"):
        """Load the mapper and checkpoint, then move the model to *device*.

        Args:
            mapper_path: Path to the serialized ``UnifiedIdMapper`` file.
            model_path: Path to the safetensors checkpoint.
            device: Requested device string ("cuda", "mps", or "cpu").
        """
        print(f"Loading mapper from {mapper_path}...")
        self.mapper = UnifiedIdMapper.from_file(mapper_path)

        # Build the network from default config; weights arrive separately.
        print("Initializing model...")
        self.model = LlamaEdgeForCausalLM(LlamaEdgeConfig())

        # Weights are always read onto CPU first, then moved with the model.
        print(f"Loading weights from {model_path}...")
        self.model.load_state_dict(load_file(model_path, device="cpu"))

        # Fall back to CPU when CUDA was requested but is unavailable;
        # any other request (including "mps") is honored verbatim.
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(device)

        print(f"Moving model to {self.device}...")
        self.model.to(self.device)
        self.model.eval()

    def predict(self, old_ids_context):
        """Score all next-token candidates for the given context.

        Args:
            old_ids_context: List of old IDs defining the context.
        Returns:
            List of ``(prob, old_id, label)`` tuples sorted by probability,
            highest first. Vocabulary slots with no mapping come back as
            ``(prob, -1, "<PAD/UNK>")``.
        """
        # Translate external ("old") ids into the model's contiguous vocab ids.
        # Inputs are assumed to be present in the mapper.
        new_ids = [self.mapper.map_old_id(old_id)[0] for old_id in old_ids_context]

        # Single-example batch on the model's device.
        batch = torch.tensor([new_ids], dtype=torch.long, device=self.device)
        with torch.no_grad():
            logits = self.model(batch)

        # Distribution over the next token only (last position in the sequence).
        next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)
        ranked_probs, ranked_new_ids = torch.sort(next_token_probs, descending=True)

        # Map every candidate back to its old id and human-readable label.
        results = []
        for prob, new_id in zip(ranked_probs.tolist(), ranked_new_ids.tolist()):
            try:
                # map_new_id returns (old_id, is_edge); the flag is unused here.
                old_id, _ = self.mapper.map_new_id(new_id)
                results.append((prob, old_id, self.mapper.label_from_new_id(new_id)))
            except KeyError:
                # Vocabulary indices absent from the mapper (e.g. padding).
                results.append((prob, -1, "<PAD/UNK>"))
        return results
def main():
    """Smoke-test the wrapper on one training example and report the top-10."""
    mapper_path = "unified_id_mapper.json"
    model_path = "model.safetensors"

    # Prefer CUDA, then Apple MPS, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    wrapper = ModelWrapper(mapper_path, model_path, device=device)

    # One instance taken from the first line of the training data.
    input_ids = [108, 112, 117, 234, 421, 582, 601, 608, 940, 941, 948, 1008, 1009, 1076, 1094, 1095, 1125, 1188, 1251, 1275, 1365, 1415, 1522, 1687, 1948, 1977, 2025, 47178924, 47185647]
    target_edge_id = 47182521
    target_edge_label = "/people/person/place_of_birth"

    predictions = wrapper.predict(input_ids)
    top_10 = predictions[:10]

    print(f"Input old IDs: {input_ids}")
    print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")
    print("Top 10 Predictions:")
    for rank, (prob, pred_old_id, pred_label) in enumerate(top_10, start=1):
        print(f" Rank {rank}: Old ID {pred_old_id}, Label: {pred_label}, Probability: {prob:.6f}")

    # Report whether the known target shows up among the top-10 candidates.
    if any(pred_old_id == target_edge_id for _, pred_old_id, _ in top_10):
        print(f"Target edge old ID {target_edge_id} found in top 10 predictions.")
    else:
        print(f"Target edge old ID {target_edge_id} NOT found in top 10 predictions.")
    print("-" * 50)


if __name__ == "__main__":
    main()