Llama Edge Prediction Model

This repository contains a custom Llama 3-based model for edge-prediction tasks: given a sequence of context IDs, it predicts the target edge.

Model Description

The model uses the Llama 3 architecture with the following configuration:

  • Standard Llama 3 8B params (dim 4096, 32 layers, 32 heads, 8 KV heads)
  • Adjusted vocab size: 9942 (custom embeddings)
  • Intermediate size: 14336
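For illustration, the configuration above can be summarized as a plain dictionary. The field names below follow Hugging Face's LlamaConfig conventions and are an assumption; the actual attribute names live in LlamaEdgeConfig:

```python
# Hypothetical summary of the configuration listed above.
# Field names assume Hugging Face LlamaConfig conventions; consult
# configuration_llama_edge.py for the authoritative definitions.
llama_edge_config = {
    "hidden_size": 4096,          # model dim
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,     # grouped-query attention
    "intermediate_size": 14336,
    "vocab_size": 9942,           # adjusted for custom node/edge embeddings
}
```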

It uses a UnifiedIdMapper to translate between original IDs (nodes/edges) and the model's internal vocabulary IDs.
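The mapper's exact interface isn't reproduced here; the pure-Python sketch below illustrates the idea of a bidirectional ID mapping. The class and method names are stand-ins, not the real UnifiedIdMapper API (see id_mapper.py for that):

```python
class IdMapperSketch:
    """Illustrative bidirectional mapping between original node/edge IDs
    and contiguous internal model IDs (not the real UnifiedIdMapper)."""

    def __init__(self, original_ids):
        # Assign each original ID a contiguous internal ID starting at 0,
        # mirroring how graph IDs would be packed into the model vocabulary.
        self.to_internal = {oid: i for i, oid in enumerate(original_ids)}
        self.to_original = {i: oid for oid, i in self.to_internal.items()}

    def map_to_internal(self, original_id):
        return self.to_internal[original_id]

    def map_to_original(self, internal_id):
        return self.to_original[internal_id]
```

For example, `IdMapperSketch(["node_42", "edge_7"])` would map `"edge_7"` to internal ID 1 and back.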

Repository Structure

  • configuration_llama_edge.py: Defines LlamaEdgeConfig (inherits from PretrainedConfig).
  • modeling_llama_edge.py: Defines LlamaEdgeForCausalLM and components (inherits from PreTrainedModel).
  • id_mapper.py: UnifiedIdMapper for ID mapping logic.
  • inference.py: Example script to run inference using the model and mapper.
  • model.safetensors: Model weights (required).
  • unified_id_mapper.json: Mapping data (required).

Usage

Loading the Model

You can load the model using the provided classes:

import torch
from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM
from id_mapper import UnifiedIdMapper

# Load configuration
config = LlamaEdgeConfig()

# Initialize model
model = LlamaEdgeForCausalLM(config)

# Load weights
from safetensors.torch import load_file
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()

Running Inference

Use the inference.py script to run an example prediction:

python inference.py
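The script's internals aren't reproduced here, but the overall prediction flow can be sketched in plain Python: map the context IDs into the internal vocabulary, take the argmax over the model's output logits, and map the predicted internal ID back to an original edge ID. All names below are stand-ins, not the actual inference.py API:

```python
def predict_edge(context_ids, mapper, logits_fn):
    """Hypothetical sketch of the edge-prediction flow.

    mapper    -- dict from original IDs to internal vocab IDs (stand-in
                 for UnifiedIdMapper)
    logits_fn -- callable returning one logit per vocab entry (stand-in
                 for the model forward pass)
    """
    # Original IDs -> internal model IDs.
    internal = [mapper[i] for i in context_ids]
    # Stand-in for the model forward pass over the mapped context.
    logits = logits_fn(internal)
    # Greedy pick: internal ID with the highest logit.
    pred_internal = max(range(len(logits)), key=logits.__getitem__)
    # Internal ID -> original edge ID.
    inverse = {v: k for k, v in mapper.items()}
    return inverse[pred_internal]
```

With a toy vocabulary `{"n1": 0, "n2": 1, "e1": 2}` and logits `[0.1, 0.2, 0.9]`, the sketch returns `"e1"`.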