crab27 committed
Commit 4dd7afe · verified · 1 Parent(s): 7ecb108

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,3 +1,122 @@
- ---
- license: mit
- ---
+ ---
+ library_name: transformers
+ tags:
+ - pytorch
+ - custom-implementation
+ - graph-prediction
+ - edge-prediction
+ ---
+ 
+ # Llama Edge Prediction Model
+ 
+ This repository contains a custom Llama 3-based model for edge prediction tasks: it predicts edge targets from a context of node and edge IDs.
+ 
+ ## Model Description
+ 
+ The model corresponds to a `Llama3` architecture with the following configuration:
+ - Standard Llama 3 8B params (dim 4096, 32 layers, 32 heads, 8 KV heads)
+ - Adjusted vocab size: 9942 (custom embeddings)
+ - Intermediate size: 14336
+ 
+ It uses a `UnifiedIdMapper` to map between original IDs (nodes/edges) and internal model IDs.
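+ 
+ As a quick illustration of the mapping scheme (a minimal sketch based on `id_mapper.py`; the example ID and its label are taken from `inference.py`):
+ 
+ ```python
+ from id_mapper import UnifiedIdMapper
+ 
+ mapper = UnifiedIdMapper.from_file("unified_id_mapper.json")
+ new_id, is_edge = mapper.map_old_id(47182521)  # internal model ID; is_edge is True for edges
+ label = mapper.label_from_old_id(47182521)     # "/people/person/place_of_birth"
+ ```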
+ 
+ ## Repository Structure
+ 
+ - `configuration_llama_edge.py`: Defines `LlamaEdgeConfig` (inherits from `PretrainedConfig`).
+ - `modeling_llama_edge.py`: Defines `LlamaEdgeForCausalLM` and its components (inherits from `PreTrainedModel`).
+ - `id_mapper.py`: `UnifiedIdMapper` for the ID-mapping logic.
+ - `inference.py`: Example script to run inference using the model and mapper.
+ - `model-0000*-of-00006.safetensors` and `model.safetensors.index.json`: Sharded model weights (required).
+ - `unified_id_mapper.json`: Mapping data (required).
+ 
+ ## Usage
+ 
+ ### Loading the Model
+ 
+ You can load the model using the provided classes:
+ 
+ ```python
+ import torch
+ from configuration_llama_edge import LlamaEdgeConfig
+ from modeling_llama_edge import LlamaEdgeForCausalLM
+ from id_mapper import UnifiedIdMapper
+ 
+ # Load configuration
+ config = LlamaEdgeConfig()
+ 
+ # Initialize model
+ model = LlamaEdgeForCausalLM(config)
+ 
+ # Load the sharded weights listed in the index file
+ import json
+ from safetensors.torch import load_file
+ 
+ with open("model.safetensors.index.json") as f:
+     shard_files = sorted(set(json.load(f)["weight_map"].values()))
+ state_dict = {}
+ for shard in shard_files:
+     state_dict.update(load_file(shard))
+ model.load_state_dict(state_dict)
+ model.eval()
+ ```
+ 
+ ### Running Inference
+ 
+ Use the `inference.py` script to run a prediction example:
+ 
+ ```bash
+ python inference.py
+ ```
+ 
+ ## Requirements
+ 
+ Install the dependencies:
+ 
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 
+ ## Note on Hugging Face Integration
+ 
+ To use with `AutoModel.from_pretrained(..., trust_remote_code=True)`, ensure `config.json` is present (it is included in this repo, generated from `LlamaEdgeConfig`). Use `register_for_auto_class` if uploading to the Hub.
+ 
+ ## How to use from Hugging Face Hub
+ 
+ Users can load this model directly from the Hub without cloning the repository.
+ 
+ ### 1. Install Dependencies
+ 
+ ```bash
+ pip install transformers torch numpy huggingface_hub
+ ```
+ 
+ ### 2. Download and Run Code
+ 
+ ```python
+ import os
+ import sys
+ import torch
+ from transformers import AutoModel
+ from huggingface_hub import hf_hub_download
+ 
+ # 1. Load the model with trust_remote_code=True
+ model_id = "your-username/your-model-name"
+ 
+ # This loads the model and the custom configuration
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
+ model.eval()
+ 
+ # 2. Load the UnifiedIdMapper
+ # The mapper/helper files can be downloaded from the Hub repo
+ mapper_path = hf_hub_download(repo_id=model_id, filename="unified_id_mapper.json")
+ 
+ # We also need the Python class definition for the mapper. It is not
+ # automatically imported by AutoModel, so download the file and make sure
+ # its directory is on the import path:
+ id_mapper_script = hf_hub_download(repo_id=model_id, filename="id_mapper.py")
+ sys.path.append(os.path.dirname(id_mapper_script))
+ from id_mapper import UnifiedIdMapper
+ 
+ # Load mapper
+ mapper = UnifiedIdMapper.from_file(mapper_path)
+ 
+ # 3. Example Usage (Manual Inference)
+ # Your inference logic here, similar to inference.py (see the sketch below)
+ ```
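+ 
+ A minimal sketch of step 3, mirroring the logic in `inference.py` (the context IDs here are hypothetical and must exist in the mapper):
+ 
+ ```python
+ context_old_ids = [108, 112, 117]
+ input_ids = [mapper.map_old_id(i)[0] for i in context_old_ids]
+ 
+ with torch.no_grad():
+     logits = model(torch.tensor([input_ids], dtype=torch.long))
+     probs = torch.softmax(logits[0, -1, :], dim=-1)
+ 
+ top_prob, top_new_id = probs.max(dim=-1)
+ old_id, is_edge = mapper.map_new_id(top_new_id.item())
+ print(old_id, mapper.label_from_new_id(top_new_id.item()), top_prob.item())
+ ```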
config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "LlamaEdgeForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_llama_edge.LlamaEdgeConfig",
+     "AutoModel": "modeling_llama_edge.LlamaEdgeForCausalLM"
+   },
+   "dim": 4096,
+   "dtype": "float32",
+   "ffn_dim_multiplier": 1.3,
+   "intermediate_size": 14336,
+   "max_seq_len": 8192,
+   "model_type": "llama_edge",
+   "multiple_of": 256,
+   "n_heads": 32,
+   "n_kv_heads": 8,
+   "n_layers": 32,
+   "norm_eps": 1e-05,
+   "rope_theta": 500000.0,
+   "transformers_version": "4.57.3",
+   "vocab_size": 9942
+ }
configuration_llama_edge.py ADDED
@@ -0,0 +1,34 @@
+ from transformers import PretrainedConfig
+ from typing import Optional
+ 
+ class LlamaEdgeConfig(PretrainedConfig):
+     model_type = "llama_edge"
+ 
+     def __init__(
+         self,
+         dim: int = 4096,
+         n_layers: int = 32,
+         n_heads: int = 32,
+         n_kv_heads: int = 8,
+         vocab_size: int = 9942,
+         multiple_of: int = 256,
+         ffn_dim_multiplier: Optional[float] = 1.3,
+         norm_eps: float = 1e-5,
+         rope_theta: float = 500000.0,
+         max_seq_len: int = 8192,
+         intermediate_size: int = 14336,
+         **kwargs,
+     ):
+         self.dim = dim
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.n_kv_heads = n_kv_heads
+         self.vocab_size = vocab_size
+         self.multiple_of = multiple_of
+         self.ffn_dim_multiplier = ffn_dim_multiplier
+         self.norm_eps = norm_eps
+         self.rope_theta = rope_theta
+         self.max_seq_len = max_seq_len
+         self.intermediate_size = intermediate_size
+ 
+         super().__init__(**kwargs)
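
The `config.json` earlier in this commit is the serialized form of this class. A quick round-trip sketch using the standard `PretrainedConfig` methods:

```python
from configuration_llama_edge import LlamaEdgeConfig

config = LlamaEdgeConfig()
config.save_pretrained(".")                    # writes config.json to the current directory
config = LlamaEdgeConfig.from_pretrained(".")  # reads it back
```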
id_mapper.py ADDED
@@ -0,0 +1,65 @@
+ from typing import Dict, List, Tuple
+ import json
+ 
+ class UnifiedIdMapper:
+     def __init__(self, nodes: Dict[int, str], edges: Dict[int, str]) -> None:
+         # Since all keys in JSON are str, convert them to int
+         nodes = {int(k): v for k, v in nodes.items()}
+         edges = {int(k): v for k, v in edges.items()}
+ 
+         self.nodes = nodes
+         self.edges = edges
+ 
+         node_mapping = {old_id: new_id for new_id, old_id in enumerate(sorted(self.nodes.keys()))}
+         edge_mapping = {old_id: new_id for new_id, old_id in enumerate(sorted(edges.keys()))}
+         shift = len(nodes)
+ 
+         self.old_to_new: Dict[int, Tuple[int, bool]] = {
+             **{old_id: (new_id, False) for old_id, new_id in node_mapping.items()},
+             **{old_id: (new_id + shift, True) for old_id, new_id in edge_mapping.items()},
+         }
+         # Reverse mapping: new_id -> (old_id, is_edge)
+         self.new_to_old: Dict[int, Tuple[int, bool]] = {
+             new_id: (old_id, is_edge)
+             for old_id, (new_id, is_edge) in self.old_to_new.items()
+         }
+ 
+         # Label maps
+         self.old_id_to_label: Dict[int, str] = {**nodes, **edges}
+         self.new_id_to_label: Dict[int, str] = {
+             new_id: self.old_id_to_label[old_id] for old_id, (new_id, _) in self.old_to_new.items()
+         }
+ 
+         self.label_to_old_ids: Dict[str, List[Tuple[int, bool]]] = {}
+         self.label_to_new_ids: Dict[str, List[Tuple[int, bool]]] = {}
+         for old_id, (new_id, is_edge) in self.old_to_new.items():
+             label = self.old_id_to_label.get(old_id)
+             if label is None:
+                 continue
+             self.label_to_old_ids.setdefault(label, []).append((old_id, is_edge))
+             self.label_to_new_ids.setdefault(label, []).append((new_id, is_edge))
+ 
+     @classmethod
+     def from_file(cls, mapper_path: str):
+         with open(mapper_path, "r") as f:
+             data = json.load(f)
+         return cls(data['nodes'], data['edges'])
+ 
+     def map_old_id(self, old_id: int) -> Tuple[int, bool]:
+         return self.old_to_new[old_id]
+ 
+     def map_new_id(self, new_id: int) -> Tuple[int, bool]:
+         return self.new_to_old[new_id]
+ 
+     def label_from_old_id(self, old_id: int) -> str:
+         return self.old_id_to_label[old_id]
+ 
+     def label_from_new_id(self, new_id: int) -> str:
+         return self.new_id_to_label[new_id]
+ 
+     def old_ids_from_label(self, label: str) -> List[Tuple[int, bool]]:
+         return self.label_to_old_ids.get(label, [])
+ 
+     def new_ids_from_label(self, label: str) -> List[Tuple[int, bool]]:
+         return self.label_to_new_ids.get(label, [])
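
A quick sanity check of the mapping scheme (toy node/edge labels, not taken from the real `unified_id_mapper.json`):

```python
from id_mapper import UnifiedIdMapper

# Nodes get internal IDs 0..N-1 (sorted by old ID); edges follow, shifted by N.
mapper = UnifiedIdMapper(nodes={"10": "Paris", "7": "France"}, edges={"99": "capital_of"})
assert mapper.map_old_id(7) == (0, False)
assert mapper.map_old_id(10) == (1, False)
assert mapper.map_old_id(99) == (2, True)   # 0 + shift, where shift == len(nodes) == 2
assert mapper.label_from_new_id(2) == "capital_of"
```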
inference.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ import json
+ from safetensors.torch import load_file
+ from id_mapper import UnifiedIdMapper
+ from modeling_llama_edge import LlamaEdgeForCausalLM
+ from configuration_llama_edge import LlamaEdgeConfig
+ 
+ class ModelWrapper:
+     def __init__(self, mapper_path, model_path, device="cuda"):
+         # Load Mapper
+         print(f"Loading mapper from {mapper_path}...")
+         self.mapper = UnifiedIdMapper.from_file(mapper_path)
+ 
+         # Initialize Empty Model
+         print("Initializing model...")
+         config = LlamaEdgeConfig()  # Use defaults or load from file if it exists
+         self.model = LlamaEdgeForCausalLM(config)
+ 
+         # Load Weights (this repo ships a sharded checkpoint plus an index file)
+         print(f"Loading weights from {model_path}...")
+         if model_path.endswith(".index.json"):
+             with open(model_path) as f:
+                 shard_files = sorted(set(json.load(f)["weight_map"].values()))
+             state_dict = {}
+             for shard in shard_files:
+                 state_dict.update(load_file(shard, device="cpu"))
+         else:
+             state_dict = load_file(model_path, device="cpu")
+         self.model.load_state_dict(state_dict)
+ 
+         # Set device
+         if device == "cuda" and not torch.cuda.is_available():
+             print("CUDA not available, switching to CPU.")
+             self.device = torch.device("cpu")
+         elif device == "mps":  # Handle MPS explicitly if requested or available
+             self.device = torch.device("mps")
+         else:
+             self.device = torch.device(device)
+ 
+         print(f"Moving model to {self.device}...")
+         self.model.to(self.device)
+         self.model.eval()
+ 
+     def predict(self, old_ids_context):
+         """
+         Args:
+             old_ids_context: List of old IDs defining the context.
+         Returns:
+             sorted_predictions: List of (prob, old_id, label) sorted by probability descending.
+         """
+         # 1. Convert context list of old IDs to new IDs
+         input_ids = []
+         for old_id in old_ids_context:
+             # We assume the input old_ids exist in the mapper
+             new_id, _ = self.mapper.map_old_id(old_id)
+             input_ids.append(new_id)
+ 
+         # 2. Run inference
+         # Create tensor on the target device (batch size = 1)
+         model_input = torch.tensor([input_ids], dtype=torch.long, device=self.device)
+ 
+         with torch.no_grad():
+             logits = self.model(model_input)
+             # Get logits for the last token in the sequence
+             last_token_logits = logits[0, -1, :]
+             probs = torch.softmax(last_token_logits, dim=-1)
+ 
+         # 3. Sort by probability descending
+         sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+ 
+         sorted_probs = sorted_probs.tolist()
+         sorted_indices = sorted_indices.tolist()  # These indices are the new_ids
+ 
+         # 4. Create result list with mapping applied
+         results = []
+         for prob, new_id in zip(sorted_probs, sorted_indices):
+             try:
+                 # map_new_id returns (old_id, is_edge)
+                 old_id, _ = self.mapper.map_new_id(new_id)
+                 label = self.mapper.label_from_new_id(new_id)
+                 results.append((prob, old_id, label))
+             except KeyError:
+                 # Handle indices not in the mapper (e.g., padding tokens)
+                 results.append((prob, -1, "<PAD/UNK>"))
+ 
+         return results
+ 
+ def main():
+     # Define paths
+     mapper_path = "unified_id_mapper.json"
+     model_path = "model.safetensors.index.json"
+ 
+     # Check for device availability
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
+     print(f"Using device: {device}")
+ 
+     # Instantiate the wrapper
+     wrapper = ModelWrapper(mapper_path, model_path, device=device)
+ 
+     # Input instance taken from the first line of the training data
+     input_ids = [108, 112, 117, 234, 421, 582, 601, 608, 940, 941, 948, 1008, 1009, 1076, 1094, 1095, 1125, 1188, 1251, 1275, 1365, 1415, 1522, 1687, 1948, 1977, 2025, 47178924, 47185647]
+     target_edge_id = 47182521
+     target_edge_label = "/people/person/place_of_birth"
+ 
+     predictions = wrapper.predict(input_ids)
+     print(f"Input old IDs: {input_ids}")
+     print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")
+ 
+     print("Top 10 Predictions:")
+     for rank, (prob, pred_old_id, pred_label) in enumerate(predictions[:10], start=1):
+         print(f"  Rank {rank}: Old ID {pred_old_id}, Label: {pred_label}, Probability: {prob:.6f}")
+ 
+     # Check if the target is in top 10
+     top_10_old_ids = [pred_old_id for _, pred_old_id, _ in predictions[:10]]
+     if target_edge_id in top_10_old_ids:
+         print(f"Target edge old ID {target_edge_id} found in top 10 predictions.")
+     else:
+         print(f"Target edge old ID {target_edge_id} NOT found in top 10 predictions.")
+     print("-" * 50)
+ 
+ if __name__ == "__main__":
+     main()
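
For scripting, `ModelWrapper` can also be used directly; a minimal sketch (the context IDs are a hypothetical subset of the example in `main()`):

```python
from inference import ModelWrapper

wrapper = ModelWrapper("unified_id_mapper.json", "model.safetensors.index.json", device="cpu")
for prob, old_id, label in wrapper.predict([108, 112, 117])[:3]:
    print(f"{label} (old ID {old_id}): {prob:.4f}")
```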
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd6efe946e1d7ffeccb33ae66d908d10a87c25a2cc93432d677b6750f6ae1a8c
+ size 4927788328
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71521cd818e18c57340cb0bae75534611f6904c8c774a7bdbc28192b9aab65f8
+ size 4999812608
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8567a1262dc2f7a4bf40e4b62ad192357dcda0e20c3cff337c86dfc518ba2ccd
+ size 4832007088
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57bcf7c994d8943add2838392e45875a7ff3477a16f43ac763c2058a00b44071
+ size 4999812648
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6c962e351aa178f4ef5fc5e55a92fd56b5acd73d565aa30695f7962bff7a523e
+ size 4999812648
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6abc5bfd2dd4c5f6a4d9d7ec83a170c5be89c2865d3f433cc2568f631686d6de
+ size 3484929560
model.safetensors.index.json ADDED
@@ -0,0 +1,299 @@
+ {
+   "metadata": {
+     "total_parameters": 7061032960,
+     "total_size": 28244131840
+   },
+   "weight_map": {
+     "layers.0.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.0.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.0.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.0.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.0.attention_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.0.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.0.feed_forward.w2.weight": "model-00001-of-00006.safetensors",
+     "layers.0.feed_forward.w3.weight": "model-00001-of-00006.safetensors",
+     "layers.0.ffn_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.1.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.1.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.1.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.1.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.1.attention_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.1.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.1.feed_forward.w2.weight": "model-00001-of-00006.safetensors",
+     "layers.1.feed_forward.w3.weight": "model-00001-of-00006.safetensors",
+     "layers.1.ffn_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.10.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.10.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.10.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.10.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.10.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.10.feed_forward.w1.weight": "model-00002-of-00006.safetensors",
+     "layers.10.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.10.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.10.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.11.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.11.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.11.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.11.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.11.attention_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.11.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.11.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.11.feed_forward.w3.weight": "model-00003-of-00006.safetensors",
+     "layers.11.ffn_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.12.attention.wk.weight": "model-00003-of-00006.safetensors",
+     "layers.12.attention.wo.weight": "model-00003-of-00006.safetensors",
+     "layers.12.attention.wq.weight": "model-00003-of-00006.safetensors",
+     "layers.12.attention.wv.weight": "model-00003-of-00006.safetensors",
+     "layers.12.attention_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.12.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.12.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.12.feed_forward.w3.weight": "model-00003-of-00006.safetensors",
+     "layers.12.ffn_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.13.attention.wk.weight": "model-00003-of-00006.safetensors",
+     "layers.13.attention.wo.weight": "model-00003-of-00006.safetensors",
+     "layers.13.attention.wq.weight": "model-00003-of-00006.safetensors",
+     "layers.13.attention.wv.weight": "model-00003-of-00006.safetensors",
+     "layers.13.attention_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.13.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.13.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.13.feed_forward.w3.weight": "model-00003-of-00006.safetensors",
+     "layers.13.ffn_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.14.attention.wk.weight": "model-00003-of-00006.safetensors",
+     "layers.14.attention.wo.weight": "model-00003-of-00006.safetensors",
+     "layers.14.attention.wq.weight": "model-00003-of-00006.safetensors",
+     "layers.14.attention.wv.weight": "model-00003-of-00006.safetensors",
+     "layers.14.attention_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.14.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.14.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.14.feed_forward.w3.weight": "model-00003-of-00006.safetensors",
+     "layers.14.ffn_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.15.attention.wk.weight": "model-00003-of-00006.safetensors",
+     "layers.15.attention.wo.weight": "model-00003-of-00006.safetensors",
+     "layers.15.attention.wq.weight": "model-00003-of-00006.safetensors",
+     "layers.15.attention.wv.weight": "model-00003-of-00006.safetensors",
+     "layers.15.attention_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.15.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.15.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.15.feed_forward.w3.weight": "model-00003-of-00006.safetensors",
+     "layers.15.ffn_norm.weight": "model-00003-of-00006.safetensors",
+     "layers.16.attention.wk.weight": "model-00003-of-00006.safetensors",
+     "layers.16.attention.wo.weight": "model-00003-of-00006.safetensors",
+     "layers.16.attention.wq.weight": "model-00003-of-00006.safetensors",
+     "layers.16.attention.wv.weight": "model-00003-of-00006.safetensors",
+     "layers.16.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.16.feed_forward.w1.weight": "model-00003-of-00006.safetensors",
+     "layers.16.feed_forward.w2.weight": "model-00003-of-00006.safetensors",
+     "layers.16.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.16.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.17.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.17.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.17.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.17.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.17.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.17.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.17.feed_forward.w2.weight": "model-00004-of-00006.safetensors",
+     "layers.17.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.17.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.18.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.18.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.18.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.18.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.18.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.18.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.18.feed_forward.w2.weight": "model-00004-of-00006.safetensors",
+     "layers.18.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.18.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.19.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.19.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.19.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.19.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.19.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.19.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.19.feed_forward.w2.weight": "model-00004-of-00006.safetensors",
+     "layers.19.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.19.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.2.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.2.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.2.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.2.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.2.attention_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.2.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.2.feed_forward.w2.weight": "model-00001-of-00006.safetensors",
+     "layers.2.feed_forward.w3.weight": "model-00001-of-00006.safetensors",
+     "layers.2.ffn_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.20.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.20.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.20.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.20.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.20.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.20.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.20.feed_forward.w2.weight": "model-00004-of-00006.safetensors",
+     "layers.20.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.20.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.21.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.21.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.21.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.21.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.21.attention_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.21.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.21.feed_forward.w2.weight": "model-00004-of-00006.safetensors",
+     "layers.21.feed_forward.w3.weight": "model-00004-of-00006.safetensors",
+     "layers.21.ffn_norm.weight": "model-00004-of-00006.safetensors",
+     "layers.22.attention.wk.weight": "model-00004-of-00006.safetensors",
+     "layers.22.attention.wo.weight": "model-00004-of-00006.safetensors",
+     "layers.22.attention.wq.weight": "model-00004-of-00006.safetensors",
+     "layers.22.attention.wv.weight": "model-00004-of-00006.safetensors",
+     "layers.22.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.22.feed_forward.w1.weight": "model-00004-of-00006.safetensors",
+     "layers.22.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.22.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.22.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.23.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.23.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.23.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.23.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.23.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.23.feed_forward.w1.weight": "model-00005-of-00006.safetensors",
+     "layers.23.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.23.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.23.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.24.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.24.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.24.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.24.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.24.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.24.feed_forward.w1.weight": "model-00005-of-00006.safetensors",
+     "layers.24.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.24.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.24.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.25.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.25.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.25.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.25.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.25.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.25.feed_forward.w1.weight": "model-00005-of-00006.safetensors",
+     "layers.25.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.25.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.25.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.26.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.26.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.26.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.26.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.26.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.26.feed_forward.w1.weight": "model-00005-of-00006.safetensors",
+     "layers.26.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.26.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.26.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.27.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.27.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.27.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.27.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.27.attention_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.27.feed_forward.w1.weight": "model-00005-of-00006.safetensors",
+     "layers.27.feed_forward.w2.weight": "model-00005-of-00006.safetensors",
+     "layers.27.feed_forward.w3.weight": "model-00005-of-00006.safetensors",
+     "layers.27.ffn_norm.weight": "model-00005-of-00006.safetensors",
+     "layers.28.attention.wk.weight": "model-00005-of-00006.safetensors",
+     "layers.28.attention.wo.weight": "model-00005-of-00006.safetensors",
+     "layers.28.attention.wq.weight": "model-00005-of-00006.safetensors",
+     "layers.28.attention.wv.weight": "model-00005-of-00006.safetensors",
+     "layers.28.attention_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.28.feed_forward.w1.weight": "model-00006-of-00006.safetensors",
+     "layers.28.feed_forward.w2.weight": "model-00006-of-00006.safetensors",
+     "layers.28.feed_forward.w3.weight": "model-00006-of-00006.safetensors",
+     "layers.28.ffn_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.29.attention.wk.weight": "model-00006-of-00006.safetensors",
+     "layers.29.attention.wo.weight": "model-00006-of-00006.safetensors",
+     "layers.29.attention.wq.weight": "model-00006-of-00006.safetensors",
+     "layers.29.attention.wv.weight": "model-00006-of-00006.safetensors",
+     "layers.29.attention_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.29.feed_forward.w1.weight": "model-00006-of-00006.safetensors",
+     "layers.29.feed_forward.w2.weight": "model-00006-of-00006.safetensors",
+     "layers.29.feed_forward.w3.weight": "model-00006-of-00006.safetensors",
+     "layers.29.ffn_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.3.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.3.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.3.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.3.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.3.attention_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.3.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.3.feed_forward.w2.weight": "model-00001-of-00006.safetensors",
+     "layers.3.feed_forward.w3.weight": "model-00001-of-00006.safetensors",
+     "layers.3.ffn_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.30.attention.wk.weight": "model-00006-of-00006.safetensors",
+     "layers.30.attention.wo.weight": "model-00006-of-00006.safetensors",
+     "layers.30.attention.wq.weight": "model-00006-of-00006.safetensors",
+     "layers.30.attention.wv.weight": "model-00006-of-00006.safetensors",
+     "layers.30.attention_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.30.feed_forward.w1.weight": "model-00006-of-00006.safetensors",
+     "layers.30.feed_forward.w2.weight": "model-00006-of-00006.safetensors",
+     "layers.30.feed_forward.w3.weight": "model-00006-of-00006.safetensors",
+     "layers.30.ffn_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.31.attention.wk.weight": "model-00006-of-00006.safetensors",
+     "layers.31.attention.wo.weight": "model-00006-of-00006.safetensors",
+     "layers.31.attention.wq.weight": "model-00006-of-00006.safetensors",
+     "layers.31.attention.wv.weight": "model-00006-of-00006.safetensors",
+     "layers.31.attention_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.31.feed_forward.w1.weight": "model-00006-of-00006.safetensors",
+     "layers.31.feed_forward.w2.weight": "model-00006-of-00006.safetensors",
+     "layers.31.feed_forward.w3.weight": "model-00006-of-00006.safetensors",
+     "layers.31.ffn_norm.weight": "model-00006-of-00006.safetensors",
+     "layers.4.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.4.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.4.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.4.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.4.attention_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.4.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.4.feed_forward.w2.weight": "model-00001-of-00006.safetensors",
+     "layers.4.feed_forward.w3.weight": "model-00001-of-00006.safetensors",
+     "layers.4.ffn_norm.weight": "model-00001-of-00006.safetensors",
+     "layers.5.attention.wk.weight": "model-00001-of-00006.safetensors",
+     "layers.5.attention.wo.weight": "model-00001-of-00006.safetensors",
+     "layers.5.attention.wq.weight": "model-00001-of-00006.safetensors",
+     "layers.5.attention.wv.weight": "model-00001-of-00006.safetensors",
+     "layers.5.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.5.feed_forward.w1.weight": "model-00001-of-00006.safetensors",
+     "layers.5.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.5.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.5.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.6.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.6.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.6.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.6.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.6.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.6.feed_forward.w1.weight": "model-00002-of-00006.safetensors",
+     "layers.6.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.6.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.6.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.7.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.7.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.7.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.7.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.7.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.7.feed_forward.w1.weight": "model-00002-of-00006.safetensors",
+     "layers.7.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.7.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.7.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.8.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.8.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.8.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.8.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.8.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.8.feed_forward.w1.weight": "model-00002-of-00006.safetensors",
+     "layers.8.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.8.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.8.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.9.attention.wk.weight": "model-00002-of-00006.safetensors",
+     "layers.9.attention.wo.weight": "model-00002-of-00006.safetensors",
+     "layers.9.attention.wq.weight": "model-00002-of-00006.safetensors",
+     "layers.9.attention.wv.weight": "model-00002-of-00006.safetensors",
+     "layers.9.attention_norm.weight": "model-00002-of-00006.safetensors",
+     "layers.9.feed_forward.w1.weight": "model-00002-of-00006.safetensors",
+     "layers.9.feed_forward.w2.weight": "model-00002-of-00006.safetensors",
+     "layers.9.feed_forward.w3.weight": "model-00002-of-00006.safetensors",
+     "layers.9.ffn_norm.weight": "model-00002-of-00006.safetensors",
+     "norm.weight": "model-00006-of-00006.safetensors",
+     "output.weight": "model-00006-of-00006.safetensors",
+     "token_embedding.weight": "model-00001-of-00006.safetensors"
+   }
+ }
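
A small sketch for inspecting the index above (the counts in the comments follow from the weight map: 9 tensors per layer × 32 layers + 3 top-level tensors):

```python
import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

shards = sorted(set(index["weight_map"].values()))
print(len(index["weight_map"]), "tensors across", len(shards), "shards")  # 291 tensors, 6 shards
print(index["metadata"]["total_size"] / 2**30, "GiB of weights")          # ~26.3 GiB (fp32)
```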
modeling_llama_edge.py ADDED
@@ -0,0 +1,144 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional
+ from transformers import PreTrainedModel
+ from configuration_llama_edge import LlamaEdgeConfig
+ 
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+ 
+     def forward(self, x):
+         output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return output * self.weight
+ 
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     # Precompute complex exponentials for Rotary Positional Embeddings (RoPE)
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
+         super().__init__()
+ 
+         # If the config provides a specific hidden_dim (intermediate_size), use it directly.
+         # Otherwise, calculate it using the standard Llama formula.
+         if hidden_dim is None:
+             hidden_dim = 4 * dim
+             hidden_dim = int(2 * hidden_dim / 3)
+             if ffn_dim_multiplier is not None:
+                 hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+             hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+ 
+         # In Llama 3 8B, this will be 14336
+         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ 
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+ 
+ class Attention(nn.Module):
+     def __init__(self, config: LlamaEdgeConfig):
+         super().__init__()
+         self.n_heads = config.n_heads
+         self.n_kv_heads = config.n_kv_heads
+         self.head_dim = config.dim // config.n_heads
+         self.n_rep = self.n_heads // self.n_kv_heads
+ 
+         self.wq = nn.Linear(config.dim, config.n_heads * self.head_dim, bias=False)
+         self.wk = nn.Linear(config.dim, config.n_kv_heads * self.head_dim, bias=False)
+         self.wv = nn.Linear(config.dim, config.n_kv_heads * self.head_dim, bias=False)
+         self.wo = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
+ 
+     def forward(self, x, freqs_cis, mask=None):
+         bsz, seqlen, _ = x.shape
+         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+ 
+         # Reshape for multi-head attention
+         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+ 
+         # Apply RoPE
+         # NOTE: left disabled in this checkpoint; freqs_cis is precomputed and passed in,
+         # but apply_rotary_emb is not defined in this file, so no positional rotation occurs.
+         # xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+ 
+         # Repeat K and V heads for GQA (if n_kv_heads < n_heads)
+         if self.n_rep > 1:
+             xk = xk.unsqueeze(3).repeat(1, 1, 1, self.n_rep, 1).reshape(bsz, seqlen, self.n_heads, self.head_dim)
+             xv = xv.unsqueeze(3).repeat(1, 1, 1, self.n_rep, 1).reshape(bsz, seqlen, self.n_heads, self.head_dim)
+ 
+         # Transpose for attention calculation: (bsz, heads, seqlen, dim)
+         xq = xq.transpose(1, 2)
+         xk = xk.transpose(1, 2)
+         xv = xv.transpose(1, 2)
+ 
+         # Scaled Dot-Product Attention
+         scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
+         # NOTE: the causal mask built in LlamaEdgeForCausalLM.forward is passed in but
+         # left unapplied here, so attention is bidirectional over the context IDs.
+         # if mask is not None:
+         #     scores = scores + mask  # Apply causal mask
+ 
+         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+         output = torch.matmul(scores, xv)
+ 
+         # Reshape back
+         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+         return self.wo(output)
+ 
+ class TransformerBlock(nn.Module):
+     def __init__(self, layer_id: int, config: LlamaEdgeConfig):
+         super().__init__()
+         self.attention = Attention(config)
+         self.feed_forward = FeedForward(
+             dim=config.dim,
+             hidden_dim=config.intermediate_size,
+             multiple_of=config.multiple_of,
+             ffn_dim_multiplier=config.ffn_dim_multiplier,
+         )
+         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ 
+     def forward(self, x, freqs_cis, mask=None):
+         h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out
+ 
+ class LlamaEdgeForCausalLM(PreTrainedModel):
+     config_class = LlamaEdgeConfig
+ 
+     def __init__(self, config: LlamaEdgeConfig):
+         super().__init__(config)
+         self.token_embedding = nn.Embedding(config.vocab_size, config.dim)
+         self.layers = nn.ModuleList([TransformerBlock(i, config) for i in range(config.n_layers)])
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+ 
+         # Precompute RoPE frequencies
+         self.freqs_cis = precompute_freqs_cis(
+             config.dim // config.n_heads, config.max_seq_len, config.rope_theta,
+         )
+ 
+     def forward(self, x):
+         bsz, seqlen = x.shape
+         freqs_cis = self.freqs_cis[:seqlen].to(x.device)
+ 
+         # Create causal mask (passed to the layers; see the note in Attention.forward)
+         mask = torch.full((seqlen, seqlen), float("-inf"), device=x.device)
+         mask = torch.triu(mask, diagonal=1)
+ 
+         h = self.token_embedding(x)
+ 
+         for layer in self.layers:
+             h = layer(h, freqs_cis, mask)
+ 
+         h = self.norm(h)
+         logits = self.output(h)
+         return logits
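
To sanity-check the architecture wiring without downloading the full checkpoint, a minimal sketch with a deliberately tiny (hypothetical) configuration:

```python
import torch
from configuration_llama_edge import LlamaEdgeConfig
from modeling_llama_edge import LlamaEdgeForCausalLM

# Tiny hypothetical config: same wiring as the real model, a fraction of the size
cfg = LlamaEdgeConfig(dim=64, n_layers=2, n_heads=4, n_kv_heads=2,
                      vocab_size=100, intermediate_size=128, max_seq_len=32)
model = LlamaEdgeForCausalLM(cfg).eval()

with torch.no_grad():
    logits = model(torch.randint(0, 100, (1, 8)))
assert logits.shape == (1, 8, 100)  # (batch, seq_len, vocab_size)
```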
unified_id_mapper.json ADDED
The diff for this file is too large to render. See raw diff