algo2217 committed on
Commit
524eab3
·
verified ·
1 Parent(s): f73efa3

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,9 +1,7 @@
1
- # triangle-5k-og Checkpoints
2
 
3
  This repository contains the final trained model and intermediate checkpoints.
4
 
5
  - The main directory contains the fully trained model (checkpoint 0).
6
  - The `checkpoints` directory contains all intermediate checkpoints.
7
 
8
- Now updated to match tetrahedron format
9
-
 
1
+ # triangle-100k-og Checkpoints
2
 
3
  This repository contains the final trained model and intermediate checkpoints.
4
 
5
  - The main directory contains the fully trained model (checkpoint 0).
6
  - The `checkpoints` directory contains all intermediate checkpoints.
7
 
 
 
checkpoints/checkpoint-100.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3fafa99fb361dae2d224a1307eb27fd7362cfb29144c5b369bf9dae370563080
3
- size 2478085
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d171f342841db6a5fc0a1893c1d8ef5306cb446a356ff45b2ad0a61e53098278
3
+ size 2295597
checkpoints/checkpoint-25.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:200f2670bbfd05c752c6af8e5e3835b47ff11e4c7286203f0bbfaa0c6fbead11
3
- size 2478049
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b30b7ea80a4a204da4284b08471c32085ae6665e37b24d9365f21cfdce1f5dd3
3
+ size 2295561
checkpoints/checkpoint-50.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:beb0e997530cf6327992063dab60027b1b74c57fcb2cfaefdfdbabcd7d1068af
3
- size 2478049
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:baa896287380f2fdc7e6dfeafbf74acfa1801fa5c2b9cd35e13a8b656b8804de
3
+ size 2295561
checkpoints/checkpoint-75.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f4a79008dd26635ed3276e069048391ff2962fa834d33e3649c9d2653fe5893a
3
- size 2478049
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2a97f4afb2cd35120cff209c44a5c9a9dfd73182a27ac332be49740a09557e7
3
+ size 2295561
convert_checkpoints.py CHANGED
@@ -1,48 +1,37 @@
1
- # convert_and_replace.py
2
- import os
3
- from pathlib import Path
4
 
5
- import torch
6
 
 
 
 
7
 
8
- def convert_checkpoint_format(checkpoint_path):
9
- """Convert nested format to direct format in place."""
10
- # Load the checkpoint
11
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
12
-
13
 
14
- model_state_dict = checkpoint
15
- print(f"Converting {checkpoint_path}: nested -> direct format")
16
-
17
- # Create new state dict with flattened keys
18
- new_state_dict = {}
19
- for key, value in model_state_dict.items():
20
- if key.startswith('model.'):
21
- # Remove 'model.' prefix
22
- new_key = key[6:] # Remove 'model.' (6 characters)
23
- print(f"Converting key: '{key}' -> '{new_key}'")
24
- new_state_dict[new_key] = value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  else:
26
- # Keep keys that don't start with 'model.'
27
- print(f"Keeping key: '{key}'")
28
- new_state_dict[key] = value
29
-
30
- # Save back in direct format
31
- torch.save(new_state_dict, checkpoint_path)
32
- print(f"Updated: {checkpoint_path}")
33
-
34
-
35
- def convert_all_checkpoints():
36
- """Convert all checkpoint files in the current directory."""
37
- checkpoints_dir = Path("checkpoints")
38
-
39
- if not checkpoints_dir.exists():
40
- print("No checkpoints directory found")
41
- return
42
-
43
- # Convert all .pt files
44
- for checkpoint_file in checkpoints_dir.glob("*.pt"):
45
- convert_checkpoint_format(checkpoint_file)
46
-
47
- if __name__ == "__main__":
48
- convert_all_checkpoints()
 
1
+ """Open every file in the checkpoints directory and change the keys. Each checkpoint is currently a nested dict. Keep only the keys that are in checkpoint['model'] and rename them so that, e.g., 'model.pos_embed.W_pos' becomes 'pos_embed.W_pos'.
 
 
2
 
3
+ """
4
 
5
+ import os
6
+ import json
7
+ import torch
8
 
9
+ checkpoints_dir = "checkpoints/"
 
 
 
 
10
 
11
+ for file in os.listdir(checkpoints_dir):
12
+ if file.endswith(".pt"):
13
+ file_path = os.path.join(checkpoints_dir, file)
14
+ print(f"Processing {file}...")
15
+
16
+ # Load the checkpoint
17
+ checkpoint = torch.load(file_path, map_location='cpu')
18
+
19
+ # Extract model keys and rename them
20
+ if 'model' in checkpoint:
21
+ model_state_dict = checkpoint['model']
22
+ converted_state_dict = {}
23
+
24
+ for key, value in model_state_dict.items():
25
+ # Remove 'model.' prefix if it exists
26
+ if key.startswith('model.'):
27
+ new_key = key[6:] # Remove 'model.' prefix
28
+ else:
29
+ new_key = key
30
+ converted_state_dict[new_key] = value
31
+
32
+ # Save the converted checkpoint as a flat dictionary
33
+ output_path = os.path.join(checkpoints_dir, f"converted_{file}")
34
+ torch.save(converted_state_dict, output_path)
35
+ print(f"Saved converted checkpoint to {output_path}")
36
  else:
37
+ print(f"Warning: No 'model' key found in {file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model.yaml CHANGED
@@ -3,9 +3,9 @@ implementation: transformer_lens
3
  model_name: default
4
  n_layers: '2'
5
  model_seed: '1'
6
- d_model: '16'
7
  n_ctx: '1024'
8
- d_head: '4'
9
  n_heads: '8'
10
  act_fn: gelu
11
  d_vocab: '5000'
 
3
  model_name: default
4
  n_layers: '2'
5
  model_seed: '1'
6
+ d_model: '8'
7
  n_ctx: '1024'
8
+ d_head: '2'
9
  n_heads: '8'
10
  act_fn: gelu
11
  d_vocab: '5000'
training.yaml CHANGED
@@ -1,4 +1,4 @@
1
- output_dir: checkpoints/triangle-5k-og
2
  overwrite_output_dir: 'False'
3
  do_train: 'False'
4
  do_eval: 'False'
@@ -28,7 +28,7 @@ warmup_steps: '0'
28
  log_level: warning
29
  log_level_replica: warning
30
  log_on_each_node: 'True'
31
- logging_dir: checkpoints/triangle-5k-og/runs/Jul03_18-23-05_842bf34089c7
32
  logging_strategy: IntervalStrategy.STEPS
33
  logging_first_step: 'True'
34
  logging_steps: '250'
@@ -64,7 +64,7 @@ eval_steps: None
64
  dataloader_num_workers: '0'
65
  dataloader_prefetch_factor: None
66
  past_index: '-1'
67
- run_name: triangle-5k-og
68
  disable_tqdm: 'False'
69
  remove_unused_columns: 'False'
70
  label_names: '[''input_ids'']'
@@ -97,7 +97,7 @@ skip_memory_metrics: 'True'
97
  use_legacy_prediction_loop: 'False'
98
  push_to_hub: 'False'
99
  resume_from_checkpoint: None
100
- hub_model_id: timaeus/triangle-5k-og
101
  hub_strategy: HubStrategy.EVERY_SAVE
102
  hub_token: None
103
  hub_private_repo: 'False'
 
1
+ output_dir: checkpoints/triangle-100k-og
2
  overwrite_output_dir: 'False'
3
  do_train: 'False'
4
  do_eval: 'False'
 
28
  log_level: warning
29
  log_level_replica: warning
30
  log_on_each_node: 'True'
31
+ logging_dir: checkpoints/triangle-100k-og/runs/Jul09_16-32-16_7be3271c880a
32
  logging_strategy: IntervalStrategy.STEPS
33
  logging_first_step: 'True'
34
  logging_steps: '250'
 
64
  dataloader_num_workers: '0'
65
  dataloader_prefetch_factor: None
66
  past_index: '-1'
67
+ run_name: triangle-100k-og
68
  disable_tqdm: 'False'
69
  remove_unused_columns: 'False'
70
  label_names: '[''input_ids'']'
 
97
  use_legacy_prediction_loop: 'False'
98
  push_to_hub: 'False'
99
  resume_from_checkpoint: None
100
+ hub_model_id: timaeus/triangle-100k-og
101
  hub_strategy: HubStrategy.EVERY_SAVE
102
  hub_token: None
103
  hub_private_repo: 'False'