---
library_name: transformers
tags:
- pytorch
- custom-implementation
- graph-prediction
- edge-prediction
---
# Llama Edge Prediction Model
This repository contains a custom Llama-3-based model for edge prediction: given a sequence of context IDs, it predicts the corresponding edge targets.
## Model Description
The model follows the Llama 3 architecture with the following configuration:
- Standard Llama 3 8B params (dim 4096, 32 layers, 32 heads, 8 KV heads)
- Adjusted vocab size: 9942 (custom embeddings)
- Intermediate size: 14336
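The configuration values above can be summarized as a plain dataclass. This is an illustrative sketch only; the field names mirror common Llama config conventions and the real definition lives in `configuration_llama_edge.py`:

```python
from dataclasses import dataclass

@dataclass
class LlamaEdgeConfigSketch:
    # Values taken from the model description; field names are assumptions
    # mirroring standard Llama configs, not the actual LlamaEdgeConfig API.
    vocab_size: int = 9942          # custom embeddings
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8    # grouped-query attention
    intermediate_size: int = 14336

cfg = LlamaEdgeConfigSketch()
print(cfg.vocab_size)  # -> 9942
```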
It uses a `UnifiedIdMapper` to map between original IDs (nodes/edges) and internal model IDs.
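The mapping role this plays can be sketched as a pair of lookup tables. This is a hypothetical illustration of the idea, not the actual `UnifiedIdMapper` API, which lives in `id_mapper.py` with its data in `unified_id_mapper.json`:

```python
class UnifiedIdMapperSketch:
    """Illustrative sketch of a bidirectional ID mapping; the real
    implementation is in id_mapper.py."""

    def __init__(self, original_to_internal):
        # Forward table: original node/edge ID -> internal model token ID
        self.original_to_internal = dict(original_to_internal)
        # Reverse table for decoding model predictions back to graph IDs
        self.internal_to_original = {v: k for k, v in self.original_to_internal.items()}

    def to_internal(self, original_id):
        return self.original_to_internal[original_id]

    def to_original(self, internal_id):
        return self.internal_to_original[internal_id]

# Toy mapping for illustration only
mapper = UnifiedIdMapperSketch({"node_17": 5, "edge_42": 6})
print(mapper.to_internal("edge_42"))  # -> 6
print(mapper.to_original(5))          # -> node_17
```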
## Repository Structure
- `configuration_llama_edge.py`: Defines `LlamaEdgeConfig` (inherits from `PretrainedConfig`).
- `modeling_llama_edge.py`: Defines `LlamaEdgeForCausalLM` and components (inherits from `PreTrainedModel`).
- `id_mapper.py`: `UnifiedIdMapper` for ID mapping logic.
- `inference.py`: Example script to run inference using the model and mapper.
- `model.safetensors`: Model weights (required).
- `unified_id_mapper.json`: Mapping data (required).
## Usage
### Loading the Model
You can load the model using the provided classes:
```python
import torch
from safetensors.torch import load_file

from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM
from id_mapper import UnifiedIdMapper  # ID mapping; see inference.py for usage

# Build the model from its configuration
config = LlamaEdgeConfig()
model = LlamaEdgeForCausalLM(config)

# Load the pretrained weights and switch to inference mode
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()
```
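Once the model is loaded, the final prediction step reduces to picking the highest-scoring internal ID from the logits and mapping it back to an original edge ID. Below is a torch-free sketch of that step with made-up logits and mapping values; the actual pipeline is in `inference.py`:

```python
def predict_edge_target(logits, internal_to_original):
    """Pick the argmax internal ID and map it back to its original edge ID.

    `logits` is a plain list of scores over internal IDs;
    `internal_to_original` is a toy stand-in for the UnifiedIdMapper lookup.
    """
    best_internal = max(range(len(logits)), key=lambda i: logits[i])
    return internal_to_original[best_internal]

# Toy logits over a 4-ID vocabulary; index 2 scores highest.
logits = [0.1, 0.3, 2.5, 0.9]
mapping = {0: "edge_a", 1: "edge_b", 2: "edge_c", 3: "edge_d"}
print(predict_edge_target(logits, mapping))  # -> edge_c
```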
### Running Inference
Use the `inference.py` script to run a prediction example:
```bash
python inference.py
```