---
library_name: transformers
tags:
  - pytorch
  - custom-implementation
  - graph-prediction
  - edge-prediction
---

# Llama Edge Prediction Model

This repository contains a custom Llama 3-based model for edge-prediction tasks: given a sequence of context IDs, it predicts edge targets.

## Model Description

The model uses a Llama 3 architecture with the following configuration:

- Standard Llama 3 8B parameters (hidden dim 4096, 32 layers, 32 attention heads, 8 KV heads)
- Adjusted vocabulary size: 9942 (custom embeddings)
- Intermediate size: 14336

It uses a UnifiedIdMapper to map between original IDs (nodes/edges) and internal model IDs.
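The exact `UnifiedIdMapper` API is not shown in this card; the sketch below illustrates the underlying idea with hypothetical names (`encode`/`decode` are assumptions, not the actual methods): original node/edge IDs are assigned dense internal IDs in the model's vocabulary, and predictions are mapped back.

```python
# Minimal sketch of bidirectional ID mapping (hypothetical API).
class SimpleIdMapper:
    def __init__(self, original_ids):
        # Assign each original node/edge ID a dense internal ID (0..N-1).
        self.to_internal = {oid: i for i, oid in enumerate(original_ids)}
        self.to_original = {i: oid for oid, i in self.to_internal.items()}

    def encode(self, original_id):
        # Original graph ID -> internal model vocabulary ID
        return self.to_internal[original_id]

    def decode(self, internal_id):
        # Internal model vocabulary ID -> original graph ID
        return self.to_original[internal_id]

mapper = SimpleIdMapper(["node_42", "edge_7", "node_9"])
print(mapper.encode("edge_7"))  # → 1
print(mapper.decode(2))         # → node_9
```

The real mapper loads its table from `unified_id_mapper.json` rather than building it in memory.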

## Repository Structure

- `configuration_llama_edge.py`: Defines `LlamaEdgeConfig` (inherits from `PretrainedConfig`).
- `modeling_llama_edge.py`: Defines `LlamaEdgeForCausalLM` and components (inherits from `PreTrainedModel`).
- `id_mapper.py`: `UnifiedIdMapper` for ID-mapping logic.
- `inference.py`: Example script to run inference using the model and mapper.
- `model.safetensors`: Model weights (required).
- `unified_id_mapper.json`: Mapping data (required).

## Usage

### Loading the Model

You can load the model using the provided classes:

```python
import torch
from safetensors.torch import load_file

from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM
from id_mapper import UnifiedIdMapper

# Load configuration and initialize the model
config = LlamaEdgeConfig()
model = LlamaEdgeForCausalLM(config)

# Load the pretrained weights and switch to inference mode
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()
```

### Running Inference

Use the `inference.py` script to run a prediction example:

```bash
python inference.py
```
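For orientation, the sketch below shows the general shape such an inference pass might take: run already-mapped context IDs through the model and take the argmax of the last position's logits as the predicted edge target. A tiny stand-in module replaces the real `LlamaEdgeForCausalLM` so the snippet is self-contained; it is an illustration of the flow, not the actual `inference.py`.

```python
import torch
import torch.nn as nn

VOCAB = 9942  # adjusted vocabulary size from the model card

class TinyStandIn(nn.Module):
    """Stand-in for LlamaEdgeForCausalLM: embeds IDs and projects to logits."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, 16)
        self.head = nn.Linear(16, VOCAB)

    def forward(self, input_ids):
        # Returns (batch, seq_len, vocab_size) logits
        return self.head(self.emb(input_ids))

model = TinyStandIn().eval()

# Context IDs already mapped to the internal vocabulary by the mapper
context_internal_ids = torch.tensor([[3, 17, 42]])
with torch.no_grad():
    logits = model(context_internal_ids)

# Greedy prediction at the last position = internal ID of the edge target
predicted = logits[0, -1].argmax().item()
print(0 <= predicted < VOCAB)  # → True
# UnifiedIdMapper would then map `predicted` back to an original edge ID.
```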