metadata
library_name: transformers
tags:
- pytorch
- custom-implementation
- graph-prediction
- edge-prediction
Llama Edge Prediction Model
This repository contains a custom Llama 3-based model for edge prediction. Given a sequence of context IDs, the model predicts the target edge.
Model Description
The model uses the Llama 3 architecture with the following configuration:
- Standard Llama 3 8B dimensions (hidden size 4096, 32 layers, 32 attention heads, 8 key/value heads)
- Vocabulary size reduced to 9942 (custom embeddings)
- Intermediate (MLP) size: 14336
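Because the vocabulary is much smaller than Llama 3's original 128,256 tokens, the total parameter count is no longer 8B. The sketch below works it out from the dimensions above. It assumes the standard Llama 3 8B layout (no bias terms, an input embedding and output head that are not tied, two RMSNorm weight vectors per layer plus a final norm); the custom model may differ in these details, so check modeling_llama_edge.py before relying on the number.

```python
def llama_param_count(vocab_size: int, dim: int = 4096, n_layers: int = 32,
                      n_heads: int = 32, n_kv_heads: int = 8,
                      intermediate_size: int = 14336,
                      tied_embeddings: bool = False) -> int:
    """Count parameters of a Llama-style decoder (assumes no bias terms)."""
    head_dim = dim // n_heads
    kv_dim = n_kv_heads * head_dim  # grouped-query attention: K/V projections are narrower
    attn = dim * dim + 2 * dim * kv_dim + dim * dim  # q, k, v, o projections
    mlp = 3 * dim * intermediate_size  # gate, up and down projections
    norms = 2 * dim  # input and post-attention RMSNorm weights
    per_layer = attn + mlp + norms
    embed = vocab_size * dim  # token embedding table
    head = 0 if tied_embeddings else vocab_size * dim  # separate LM head (assumed untied)
    final_norm = dim
    return n_layers * per_layer + embed + head + final_norm


print(f"{llama_param_count(128256):,}")  # original Llama 3 vocab -> 8,030,261,248
print(f"{llama_param_count(9942):,}")    # this model's vocab    -> 7,061,032,960
```

Under these assumptions, shrinking the vocabulary removes roughly 970M parameters, almost all of them from the embedding table and output head; the 32 transformer layers are unchanged.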
It uses a UnifiedIdMapper to map between original IDs (nodes/edges) and internal model IDs.
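The actual mapping logic lives in id_mapper.py and its data in unified_id_mapper.json, neither of which is documented here. As a hypothetical illustration only, a two-way mapper that assigns contiguous internal IDs to original node and edge IDs could look like the following. The class ToyIdMapper, its methods, and the (kind, id) keying are assumptions for this sketch, not the UnifiedIdMapper API.

```python
from dataclasses import dataclass, field


@dataclass
class ToyIdMapper:
    """Hypothetical two-way mapping between original IDs and internal model IDs."""
    to_internal: dict = field(default_factory=dict)
    to_original: dict = field(default_factory=dict)

    def add(self, kind: str, original_id) -> int:
        # Key by kind so a node and an edge sharing the same raw ID do not collide.
        key = (kind, original_id)
        if key not in self.to_internal:
            internal = len(self.to_internal)  # next free contiguous internal ID
            self.to_internal[key] = internal
            self.to_original[internal] = key
        return self.to_internal[key]

    def encode(self, kind: str, original_id) -> int:
        return self.to_internal[(kind, original_id)]

    def decode(self, internal_id: int):
        return self.to_original[internal_id]


mapper = ToyIdMapper()
mapper.add("node", 17)
mapper.add("edge", 17)
print(mapper.encode("edge", 17))  # 1
print(mapper.decode(0))           # ('node', 17)
```

Keeping internal IDs contiguous is what lets them index directly into an embedding table of fixed size (9942 rows here).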
Repository Structure
- configuration_llama_edge.py: defines LlamaEdgeConfig (inherits from PretrainedConfig).
- modeling_llama_edge.py: defines LlamaEdgeForCausalLM and its components (inherits from PreTrainedModel).
- id_mapper.py: UnifiedIdMapper for the ID mapping logic.
- inference.py: example script that runs inference with the model and mapper.
- model.safetensors: model weights (required).
- unified_id_mapper.json: mapping data (required).
Usage
Loading the Model
You can load the model using the provided classes:
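```python
import torch
from safetensors.torch import load_file

from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM
from id_mapper import UnifiedIdMapper  # translates original IDs at inference time

# Load configuration (the defaults are expected to match the released checkpoint)
config = LlamaEdgeConfig()

# Build the model from the configuration
model = LlamaEdgeForCausalLM(config)

# Load the trained weights; strict loading raises on missing or unexpected keys
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode (disables dropout)
```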
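If load_state_dict reports a shape mismatch, it can help to inspect the checkpoint without loading PyTorch at all. A .safetensors file starts with an 8-byte little-endian length followed by a JSON header listing each tensor's dtype, shape and byte offsets. The helper names below (read_safetensors_shapes, tensors_with_dim) are hypothetical, written for this sketch; they only read that header.

```python
import json
import struct


def read_safetensors_shapes(path: str) -> dict:
    """Return {tensor_name: (dtype, shape)} from a .safetensors header without loading data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry, not a tensor
    return {name: (info["dtype"], tuple(info["shape"])) for name, info in header.items()}


def tensors_with_dim(shapes: dict, size: int) -> list:
    """Names of tensors whose leading dimension equals `size` (e.g. the vocab size)."""
    return sorted(name for name, (_, shape) in shapes.items() if shape and shape[0] == size)
```

For example, tensors_with_dim(read_safetensors_shapes("model.safetensors"), 9942) would be expected to list the embedding and output-head matrices if the checkpoint matches the 9942-entry vocabulary; the exact tensor names depend on modeling_llama_edge.py.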
Running Inference
Use the inference.py script to run an example prediction:

```bash
python inference.py
```
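The contents of inference.py are not reproduced here. Conceptually, a prediction maps the original context IDs to internal IDs, scores every vocabulary entry, takes the highest-scoring internal ID, and maps it back to an original ID. The sketch below shows that round trip with plain dictionaries; predict_edge, the dictionary layout, and score_fn are hypothetical placeholders standing in for UnifiedIdMapper and a forward pass of LlamaEdgeForCausalLM.

```python
from typing import Callable, Dict, Hashable, List, Sequence


def predict_edge(
    context: Sequence[Hashable],
    to_internal: Dict[Hashable, int],
    to_original: Dict[int, Hashable],
    score_fn: Callable[[List[int]], Sequence[float]],
) -> Hashable:
    """Map context IDs in, score the vocabulary, and map the best internal ID back out."""
    internal_ids = [to_internal[c] for c in context]  # original -> internal IDs
    scores = score_fn(internal_ids)  # stand-in for the final position's logits
    best = max(range(len(scores)), key=scores.__getitem__)  # greedy argmax
    return to_original[best]  # internal ID -> original edge ID


def dummy_scores(ids: List[int]) -> List[float]:
    # Toy scorer over a 4-entry vocabulary that always favours internal ID 3.
    return [0.1, 0.2, 0.3, 0.9]


to_internal = {"n1": 0, "n2": 1, "e1": 2, "e2": 3}
to_original = {v: k for k, v in to_internal.items()}
print(predict_edge(["n1", "n2"], to_internal, to_original, dummy_scores))  # e2
```

With the real model, score_fn would wrap a forward pass and pick out the last position's logits; the output format of LlamaEdgeForCausalLM is not documented here, so check modeling_llama_edge.py. A real implementation might also restrict the argmax to edge IDs so that node IDs are never returned as predictions.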