Upload folder using huggingface_hub
Browse files- README.md +1 -3
- checkpoints/checkpoint-100.pt +2 -2
- checkpoints/checkpoint-25.pt +2 -2
- checkpoints/checkpoint-50.pt +2 -2
- checkpoints/checkpoint-75.pt +2 -2
- convert_checkpoints.py +32 -43
- model.yaml +2 -2
- training.yaml +4 -4
README.md
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
-
# triangle-
|
| 2 |
|
| 3 |
This repository contains the final trained model and intermediate checkpoints.
|
| 4 |
|
| 5 |
- The main directory contains the fully trained model (checkpoint 0).
|
| 6 |
- The `checkpoints` directory contains all intermediate checkpoints.
|
| 7 |
|
| 8 |
-
Now updated to match tetrahedron format
|
| 9 |
-
|
|
|
|
| 1 |
+
# triangle-100k-og Checkpoints
|
| 2 |
|
| 3 |
This repository contains the final trained model and intermediate checkpoints.
|
| 4 |
|
| 5 |
- The main directory contains the fully trained model (checkpoint 0).
|
| 6 |
- The `checkpoints` directory contains all intermediate checkpoints.
|
| 7 |
|
|
|
|
|
|
checkpoints/checkpoint-100.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d171f342841db6a5fc0a1893c1d8ef5306cb446a356ff45b2ad0a61e53098278
|
| 3 |
+
size 2295597
|
checkpoints/checkpoint-25.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b30b7ea80a4a204da4284b08471c32085ae6665e37b24d9365f21cfdce1f5dd3
|
| 3 |
+
size 2295561
|
checkpoints/checkpoint-50.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:baa896287380f2fdc7e6dfeafbf74acfa1801fa5c2b9cd35e13a8b656b8804de
|
| 3 |
+
size 2295561
|
checkpoints/checkpoint-75.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2a97f4afb2cd35120cff209c44a5c9a9dfd73182a27ac332be49740a09557e7
|
| 3 |
+
size 2295561
|
convert_checkpoints.py
CHANGED
|
@@ -1,48 +1,37 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
from pathlib import Path
|
| 4 |
|
| 5 |
-
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
"""Convert nested format to direct format in place."""
|
| 10 |
-
# Load the checkpoint
|
| 11 |
-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
else:
|
| 26 |
-
|
| 27 |
-
print(f"Keeping key: '{key}'")
|
| 28 |
-
new_state_dict[key] = value
|
| 29 |
-
|
| 30 |
-
# Save back in direct format
|
| 31 |
-
torch.save(new_state_dict, checkpoint_path)
|
| 32 |
-
print(f"Updated: {checkpoint_path}")
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def convert_all_checkpoints():
|
| 36 |
-
"""Convert all checkpoint files in the current directory."""
|
| 37 |
-
checkpoints_dir = Path("checkpoints")
|
| 38 |
-
|
| 39 |
-
if not checkpoints_dir.exists():
|
| 40 |
-
print("No checkpoints directory found")
|
| 41 |
-
return
|
| 42 |
-
|
| 43 |
-
# Convert all .pt files
|
| 44 |
-
for checkpoint_file in checkpoints_dir.glob("*.pt"):
|
| 45 |
-
convert_checkpoint_format(checkpoint_file)
|
| 46 |
-
|
| 47 |
-
if __name__ == "__main__":
|
| 48 |
-
convert_all_checkpoints()
|
|
|
|
| 1 |
+
""" Open every file in the checkpoints directory and change the keys. keys are currently a nested dict. keep only keys that are in checkpoint['model'] and rename those keys such that, e.g. 'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'.
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
"""
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
|
| 9 |
+
checkpoints_dir = "checkpoints/"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
for file in os.listdir(checkpoints_dir):
|
| 12 |
+
if file.endswith(".pt"):
|
| 13 |
+
file_path = os.path.join(checkpoints_dir, file)
|
| 14 |
+
print(f"Processing {file}...")
|
| 15 |
+
|
| 16 |
+
# Load the checkpoint
|
| 17 |
+
checkpoint = torch.load(file_path, map_location='cpu')
|
| 18 |
+
|
| 19 |
+
# Extract model keys and rename them
|
| 20 |
+
if 'model' in checkpoint:
|
| 21 |
+
model_state_dict = checkpoint['model']
|
| 22 |
+
converted_state_dict = {}
|
| 23 |
+
|
| 24 |
+
for key, value in model_state_dict.items():
|
| 25 |
+
# Remove 'model.' prefix if it exists
|
| 26 |
+
if key.startswith('model.'):
|
| 27 |
+
new_key = key[6:] # Remove 'model.' prefix
|
| 28 |
+
else:
|
| 29 |
+
new_key = key
|
| 30 |
+
converted_state_dict[new_key] = value
|
| 31 |
+
|
| 32 |
+
# Save the converted checkpoint as a flat dictionary
|
| 33 |
+
output_path = os.path.join(checkpoints_dir, f"converted_{file}")
|
| 34 |
+
torch.save(converted_state_dict, output_path)
|
| 35 |
+
print(f"Saved converted checkpoint to {output_path}")
|
| 36 |
else:
|
| 37 |
+
print(f"Warning: No 'model' key found in {file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.yaml
CHANGED
|
@@ -3,9 +3,9 @@ implementation: transformer_lens
|
|
| 3 |
model_name: default
|
| 4 |
n_layers: '2'
|
| 5 |
model_seed: '1'
|
| 6 |
-
d_model: '
|
| 7 |
n_ctx: '1024'
|
| 8 |
-
d_head: '
|
| 9 |
n_heads: '8'
|
| 10 |
act_fn: gelu
|
| 11 |
d_vocab: '5000'
|
|
|
|
| 3 |
model_name: default
|
| 4 |
n_layers: '2'
|
| 5 |
model_seed: '1'
|
| 6 |
+
d_model: '8'
|
| 7 |
n_ctx: '1024'
|
| 8 |
+
d_head: '2'
|
| 9 |
n_heads: '8'
|
| 10 |
act_fn: gelu
|
| 11 |
d_vocab: '5000'
|
training.yaml
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
output_dir: checkpoints/triangle-
|
| 2 |
overwrite_output_dir: 'False'
|
| 3 |
do_train: 'False'
|
| 4 |
do_eval: 'False'
|
|
@@ -28,7 +28,7 @@ warmup_steps: '0'
|
|
| 28 |
log_level: warning
|
| 29 |
log_level_replica: warning
|
| 30 |
log_on_each_node: 'True'
|
| 31 |
-
logging_dir: checkpoints/triangle-
|
| 32 |
logging_strategy: IntervalStrategy.STEPS
|
| 33 |
logging_first_step: 'True'
|
| 34 |
logging_steps: '250'
|
|
@@ -64,7 +64,7 @@ eval_steps: None
|
|
| 64 |
dataloader_num_workers: '0'
|
| 65 |
dataloader_prefetch_factor: None
|
| 66 |
past_index: '-1'
|
| 67 |
-
run_name: triangle-
|
| 68 |
disable_tqdm: 'False'
|
| 69 |
remove_unused_columns: 'False'
|
| 70 |
label_names: '[''input_ids'']'
|
|
@@ -97,7 +97,7 @@ skip_memory_metrics: 'True'
|
|
| 97 |
use_legacy_prediction_loop: 'False'
|
| 98 |
push_to_hub: 'False'
|
| 99 |
resume_from_checkpoint: None
|
| 100 |
-
hub_model_id: timaeus/triangle-
|
| 101 |
hub_strategy: HubStrategy.EVERY_SAVE
|
| 102 |
hub_token: None
|
| 103 |
hub_private_repo: 'False'
|
|
|
|
| 1 |
+
output_dir: checkpoints/triangle-100k-og
|
| 2 |
overwrite_output_dir: 'False'
|
| 3 |
do_train: 'False'
|
| 4 |
do_eval: 'False'
|
|
|
|
| 28 |
log_level: warning
|
| 29 |
log_level_replica: warning
|
| 30 |
log_on_each_node: 'True'
|
| 31 |
+
logging_dir: checkpoints/triangle-100k-og/runs/Jul09_16-32-16_7be3271c880a
|
| 32 |
logging_strategy: IntervalStrategy.STEPS
|
| 33 |
logging_first_step: 'True'
|
| 34 |
logging_steps: '250'
|
|
|
|
| 64 |
dataloader_num_workers: '0'
|
| 65 |
dataloader_prefetch_factor: None
|
| 66 |
past_index: '-1'
|
| 67 |
+
run_name: triangle-100k-og
|
| 68 |
disable_tqdm: 'False'
|
| 69 |
remove_unused_columns: 'False'
|
| 70 |
label_names: '[''input_ids'']'
|
|
|
|
| 97 |
use_legacy_prediction_loop: 'False'
|
| 98 |
push_to_hub: 'False'
|
| 99 |
resume_from_checkpoint: None
|
| 100 |
+
hub_model_id: timaeus/triangle-100k-og
|
| 101 |
hub_strategy: HubStrategy.EVERY_SAVE
|
| 102 |
hub_token: None
|
| 103 |
hub_private_repo: 'False'
|