---
library_name: transformers
tags:
- pytorch
- custom-implementation
- graph-prediction
- edge-prediction
---

# Llama Edge Prediction Model

This repository contains a custom Llama 3-based model for edge prediction. It predicts edge targets from context IDs.

## Model Description

The model follows the Llama 3 architecture with the following configuration:
- Standard Llama 3 8B hyperparameters (hidden dimension 4096, 32 layers, 32 attention heads, 8 KV heads)
- Vocabulary size adjusted to 9942 (custom embeddings)
- Intermediate size: 14336

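As a quick sanity check on the configuration above, the sketch below derives the per-head dimension and a rough parameter count. It assumes the standard Llama layout (three MLP projections, RMSNorm weights, no biases) and an untied output head; the custom implementation in `modeling_llama_edge.py` may differ, and all variable names are illustrative.

```python
# Illustrative arithmetic only; the layout assumptions are not taken from the repository.
dim, n_layers, n_heads, n_kv_heads = 4096, 32, 32, 8
vocab_size, intermediate_size = 9942, 14336

head_dim = dim // n_heads                # size of each attention head
kv_dim = n_kv_heads * head_dim           # key/value width under grouped-query attention

attn = 2 * dim * dim + 2 * dim * kv_dim  # q and o projections, then k and v projections
mlp = 3 * dim * intermediate_size        # gate, up, and down projections
norms = 2 * dim                          # two RMSNorm weight vectors per layer
per_layer = attn + mlp + norms

embeddings = vocab_size * dim            # input embedding table
lm_head = vocab_size * dim               # assumed untied output projection
total = n_layers * per_layer + embeddings + lm_head + dim  # trailing dim = final norm

print(f"head_dim={head_dim}, kv_dim={kv_dim}")
print(f"approx. parameters: {total:,}")
```

Under these assumptions the total comes to roughly 7.06 billion parameters rather than 8 billion, because the 9942-entry vocabulary shrinks the embedding and output matrices.
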
It uses a `UnifiedIdMapper` to map between original IDs (nodes and edges) and internal model IDs.

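To illustrate the idea of a bidirectional ID mapping, here is a minimal sketch. `SketchIdMapper`, its `encode`/`decode` methods, and the example ID strings are hypothetical; they are not the API of the repository's `UnifiedIdMapper`, whose interface is defined in `id_mapper.py`.

```python
from typing import Dict, Hashable, Iterable, List


class SketchIdMapper:
    """Hypothetical bidirectional mapper; NOT the repository's UnifiedIdMapper."""

    def __init__(self, original_ids: Iterable[Hashable]) -> None:
        self._to_model: Dict[Hashable, int] = {}
        self._to_original: List[Hashable] = []
        for original_id in original_ids:
            # Assign consecutive model IDs in first-seen order, skipping duplicates.
            if original_id not in self._to_model:
                self._to_model[original_id] = len(self._to_original)
                self._to_original.append(original_id)

    def encode(self, original_id: Hashable) -> int:
        """Original ID -> internal model ID."""
        return self._to_model[original_id]

    def decode(self, model_id: int) -> Hashable:
        """Internal model ID -> original ID."""
        return self._to_original[model_id]

    def __len__(self) -> int:
        return len(self._to_original)


# Made-up IDs for illustration only.
mapper = SketchIdMapper(["node:17", "node:42", "edge:17->42"])
context = [mapper.encode("node:17"), mapper.encode("node:42")]
print(context)           # [0, 1]
print(mapper.decode(2))  # edge:17->42
```

Keeping nodes and edges in one shared ID space, as the name `UnifiedIdMapper` suggests, lets a single embedding table cover both kinds of token.
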
## Repository Structure

- `configuration_llama_edge.py`: Defines `LlamaEdgeConfig` (inherits from `PretrainedConfig`).
- `modeling_llama_edge.py`: Defines `LlamaEdgeForCausalLM` and its components (inherits from `PreTrainedModel`).
- `id_mapper.py`: Defines `UnifiedIdMapper`, which holds the ID mapping logic.
- `inference.py`: Example script that runs inference using the model and mapper.
- `model.safetensors`: Model weights (required).
- `unified_id_mapper.json`: Mapping data (required).

## Usage

### Loading the Model

Load the model with the classes provided in this repository. The local imports below require the repository files to be on the Python path:

```python
import torch
from safetensors.torch import load_file

from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM
from id_mapper import UnifiedIdMapper  # ID mapping; see inference.py for usage

# Build the configuration (assumes the defaults match the released weights)
config = LlamaEdgeConfig()

# Initialize the model; weights are random until the checkpoint is loaded
model = LlamaEdgeForCausalLM(config)

# Load the trained weights from the safetensors checkpoint
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)

# Switch to evaluation mode
model.eval()
```

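If `load_state_dict` fails, a common cause is a mismatch between the parameter names in the checkpoint and the names the model expects. The helper below is a minimal sketch for diagnosing that. `diff_keys` is a hypothetical name, and the key strings in the example are made up; in practice the inputs would be `model.state_dict().keys()` and `state_dict.keys()`.

```python
from typing import Dict, Iterable, List


def diff_keys(expected: Iterable[str], loaded: Iterable[str]) -> Dict[str, List[str]]:
    """Report parameter names that appear on only one side."""
    expected_set, loaded_set = set(expected), set(loaded)
    return {
        "missing": sorted(expected_set - loaded_set),     # model wants, checkpoint lacks
        "unexpected": sorted(loaded_set - expected_set),  # checkpoint has, model ignores
    }


# Made-up key names for illustration only.
report = diff_keys(
    expected=["embed.weight", "norm.weight", "lm_head.weight"],
    loaded=["embed.weight", "norm.weight", "output.weight"],
)
print(report)
```

PyTorch reports the same two lists itself when `load_state_dict` is called with `strict=False`, so this helper is mainly useful for inspecting a checkpoint before building the model.
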
### Running Inference

Run the `inference.py` script for a prediction example:

```bash
python inference.py
```
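
The contents of `inference.py` are not reproduced here. As a rough illustration of the final step of such a pipeline, the sketch below picks the highest-scoring model IDs and maps them back to original IDs. It assumes the model yields one score per internal ID for the next position; `top_k_targets`, the score values, and the ID strings are all hypothetical.

```python
import heapq
from typing import Callable, Hashable, List, Sequence, Tuple


def top_k_targets(
    scores: Sequence[float],
    k: int,
    decode: Callable[[int], Hashable],
) -> List[Tuple[Hashable, float]]:
    """Return the k best-scoring IDs, decoded to original IDs, best first."""
    best = heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)
    return [(decode(model_id), scores[model_id]) for model_id in best]


# Made-up lookup table and scores standing in for the mapper and model output.
id_table = ["node:17", "node:42", "edge:17->42", "edge:42->17"]
scores = [0.1, -1.3, 2.4, 0.9]

predictions = top_k_targets(scores, k=2, decode=id_table.__getitem__)
print(predictions)  # [('edge:17->42', 2.4), ('edge:42->17', 0.9)]
```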