sadhumitha-s committed on
Commit
e73506b
·
1 Parent(s): 33a0021

feat: package model weights, SAE checkpoints, and dynamic trajectories using Git LFS

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
.github/workflows/hf_sync.yml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Mirrors the GitHub repository to the Hugging Face Space on every push to
# main (or on manual dispatch). Requires an HF_TOKEN repository secret with
# write access to the Space.
name: Sync to Hugging Face Spaces

on:
  push:
    branches: [main]
  workflow_dispatch: # Allows manual syncing from the GitHub Action tab

jobs:
  sync:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v4
        with:
          # Full history is needed so the force push carries every commit.
          fetch-depth: 0
          # Download the real LFS objects (*.pt). Without this the runner only
          # has pointer files and the LFS pre-push hook cannot upload them.
          lfs: true

      - name: Debug Secret Presence
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # Read the secret from the environment instead of splicing it into
          # the script text with an expression.
          if [ -z "$HF_TOKEN" ]; then
            echo "❌ HF_TOKEN is empty or missing in GitHub Repository Secrets!"
            exit 1
          else
            echo "✅ HF_TOKEN is successfully configured in GitHub Repository Secrets (len: ${#HF_TOKEN})."
          fi

      - name: Push to Hugging Face
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # 1. Initialize Git LFS inside the runner
          git lfs install

          # 2. Configure a named remote to allow Git LFS to authenticate correctly
          git remote add hf "https://sadhumitha-s:${HF_TOKEN}@huggingface.co/spaces/sadhumitha-s/DT-Explorer"

          # 3. Force push using HEAD:main to the hf remote.
          #    stderr is captured so the token can be masked before printing.
          git push --force hf HEAD:main 2> push_err.txt || {
            echo "=== HUGGING FACE PUSH ERROR LOG ==="
            if [ -n "$HF_TOKEN" ]; then
              sed "s/$HF_TOKEN/*****_TOKEN/g" push_err.txt
            else
              cat push_err.txt
            fi
            exit 1
          }
48
+
49
+
50
+
51
+
52
+
53
+
54
+
.gitignore CHANGED
@@ -26,12 +26,14 @@ venv/
26
  ENV/
27
 
28
  # Data and Models
29
- data/
30
- models/*.pt
31
- models/*.pth
 
 
32
  *.zip
33
  *.h5
34
- *.pt
35
 
36
  # Experiment Tracking
37
  wandb/
@@ -55,6 +57,8 @@ static/
55
  .venv
56
 
57
  /PRD.md
58
- artifacts/
 
 
59
  scratch/
60
  *.log
 
26
  ENV/
27
 
28
  # Data and Models
29
+ data/*
30
+ !data/trajectories_demo.pt
31
+ models/*
32
+ !models/mini_dt.pt
33
+ *.pth
34
  *.zip
35
  *.h5
36
+
37
 
38
  # Experiment Tracking
39
  wandb/
 
57
  .venv
58
 
59
  /PRD.md
60
+ artifacts/*
61
+ !artifacts/saes/
62
+ !artifacts/saes/*
63
  scratch/
64
  *.log
Dockerfile ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the DT-Explorer Streamlit dashboard.
FROM python:3.10-slim

WORKDIR /app

# Install system dependencies required for compilation and Gymnasium rendering.
# libgl1 is used instead of libgl1-mesa-glx: the latter is only a transitional
# package and has no installation candidate on newer Debian base images.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy and install python dependencies first to cache this layer
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Create necessary directories
RUN mkdir -p models data artifacts/saes

# Copy trained model weights and SAE features
COPY models/mini_dt.pt ./models/mini_dt.pt
COPY artifacts/saes/ ./artifacts/saes/

# Bake in the lightweight demo trajectories as the default dataset
# (stored under the name the dashboard falls back to: data/trajectories.pt)
COPY data/trajectories_demo.pt ./data/trajectories.pt

# Copy codebase
COPY src/ ./src/

# Expose default Streamlit port
EXPOSE 8501

# Streamlit configurations for production/cloud environments
ENV STREAMLIT_SERVER_PORT=8501
ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
ENV STREAMLIT_SERVER_HEADLESS=true
ENV STREAMLIT_SERVER_ENABLE_CORS=false
ENV STREAMLIT_SERVER_ENABLE_XSRF=true
ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false

# Start the dashboard application
CMD ["streamlit", "run", "src/dashboard/app.py"]
Makefile CHANGED
@@ -1,4 +1,4 @@
1
- .PHONY: setup train dashboard test clean
2
 
3
  # Setup environment
4
  setup:
@@ -17,8 +17,44 @@ dashboard:
17
  test:
18
  PYTHONPATH=. pytest tests/
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Remove artifacts and cached files
21
  clean:
22
  rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
23
  find . -type d -name "__pycache__" -exec rm -rf {} +
24
  find . -type d -name ".pytest_cache" -exec rm -rf {} +
 
 
1
+ .PHONY: setup train dashboard test clean deploy
2
 
3
  # Setup environment
4
  setup:
 
17
  test:
18
  PYTHONPATH=. pytest tests/
19
 
20
# Package and deploy to Hugging Face Spaces.
#
# The slicing script is kept in a multi-line variable and handed to Python via
# the environment. A `python3 -c '...; if ...: ...'` one-liner cannot work:
# compound statements (if/while) are not allowed after `;`, so real newlines
# and indentation are required.
define SLICE_TRAJECTORIES_PY
import os
import sys
import torch

SOURCE = "data/trajectories.pt"
TARGET = "data/trajectories_demo.pt"
LIMIT_MB = 9.5


def save_and_measure(items):
    torch.save(items, TARGET)
    return os.path.getsize(TARGET) / (1024*1024)


data = torch.load(SOURCE, map_location="cpu", weights_only=False)
count = len(data)
size_mb = save_and_measure(data[:count])

if size_mb >= LIMIT_MB:
    avg_size = size_mb / count
    count = int(9.0 / avg_size)
    while count > 0:
        size_mb = save_and_measure(data[:count])
        if size_mb < LIMIT_MB:
            break
        count -= 1

if size_mb >= LIMIT_MB:
    sys.exit(f"No prefix of the trajectories fits under {LIMIT_MB} MB (file is {size_mb:.2f} MB). Aborting deploy.")

print(f"Successfully packaged {count}/{len(data)} trajectories (Size: {size_mb:.2f} MB). Safely under 10MB limit.")
endef
export SLICE_TRAJECTORIES_PY

deploy:
	@echo "1. Slicing trajectories to data/trajectories_demo.pt (with zero-hardcoded guardrail)..."
	@echo "=========================================================="
	@python3 -c "$$SLICE_TRAJECTORIES_PY"
	@echo "Done."
	@echo ""
	@echo "=========================================================="
	@echo "2. Staging and committing deployment assets..."
	@echo "=========================================================="
	@git add data/trajectories_demo.pt models/mini_dt.pt artifacts/saes/ .gitignore Dockerfile docker-compose.yml Makefile scripts/deploy.sh src/dashboard/app.py .github/workflows/hf_sync.yml
	@git commit -m "feat: redeploy fresh model weights and demo trajectories" || echo "No new changes to commit."
	@echo ""
	@echo "=========================================================="
	@echo "3. Pushing changes to Hugging Face Spaces ('hf' remote)..."
	@echo "=========================================================="
	@git push hf main || echo "Failed to push to 'hf' remote automatically. Please verify your Space git remote is named 'hf', or manually push to your target remote (e.g. 'git push origin main')."
54
+
55
  # Remove artifacts and cached files
56
  clean:
57
  rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
58
  find . -type d -name "__pycache__" -exec rm -rf {} +
59
  find . -type d -name ".pytest_cache" -exec rm -rf {} +
60
+
artifacts/saes/blocks_0_hook_resid_post_sae.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea054a846170dc71b71c8a65dd9091dfb670bf177641a63633b39281c270e658
3
+ size 1056559
artifacts/saes/blocks_1_hook_resid_post_sae.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5736510ba07d2dc7c8481a7842011da02987ae4e875c68f28d06b2b0bc0f3c6c
3
+ size 1056559
data/trajectories_demo.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b249672da20d83956eb4344a55eff64f1444a761fcd745a0873824e6f13ece4
3
+ size 8599401
docker-compose.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Local stack for running the DT-Explorer dashboard image built from ./Dockerfile.
version: '3.8'

services:
  dt-explorer:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: dt_explorer
    ports:
      # host:container — 8501 is the Streamlit port exposed by the Dockerfile
      - "8501:8501"
    volumes:
      # Mount models and data directories to reflect updates in real-time.
      # These bind mounts sit on top of the files the Dockerfile copies into
      # /app, so the container reads the host's copies instead.
      - ./models:/app/models
      - ./data:/app/data
      - ./artifacts:/app/artifacts
    environment:
      - PYTHONUNBUFFERED=1
      # Optional: set wandb or neuronpedia tokens here if needed
      # - WANDB_API_KEY=${WANDB_API_KEY}
    restart: unless-stopped
models/mini_dt.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59f7d0f9708e6c2a22a56f369df757df57d2a0d750535ed4435f52dea89f5fcd
3
+ size 4605691
scripts/deploy.sh ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -e

# DT-Explorer Automated Deployment Script
# This script handles the raw bash workflow for solo researchers to update their hosted web app.
# Steps: slice trajectories into a size-capped demo file, stage, commit, push to origin.
# Because of `set -e`, a failure in any step (including the slicing script) aborts the deploy.

echo "=========================================================="
# 1. Slice heavy local trajectories into a lightweight demo set (zero-hardcoded dynamic scaling)
echo "1. Slicing local trajectories down to data/trajectories_demo.pt..."
echo "=========================================================="
python3 -c '
import os
import sys
import torch

LIMIT_MB = 9.5

data = torch.load("data/trajectories.pt", map_location="cpu", weights_only=False)
count = len(data)
# Try full dataset first
torch.save(data[:count], "data/trajectories_demo.pt")
size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)

if size_mb >= LIMIT_MB:
    # Calculate average size per trajectory and estimate safe capacity
    avg_size = size_mb / count
    count = int(9.0 / avg_size)  # Aim for ~9.0 MB to be safe

    # Verify and make minor adjustments if needed
    while count > 0:
        demo_data = data[:count]
        torch.save(demo_data, "data/trajectories_demo.pt")
        size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)
        if size_mb < LIMIT_MB:
            break
        count -= 1

# If the estimate was 0 or the loop ran out, the file on disk is still too
# large. Fail here so an oversized file is never staged and pushed.
if size_mb >= LIMIT_MB:
    sys.exit(f"No prefix of the trajectories fits under {LIMIT_MB} MB (file is {size_mb:.2f} MB). Aborting deploy.")

print(f"Successfully packaged {count}/{len(data)} trajectories (Size: {size_mb:.2f} MB). Safely under 10MB limit.")
'
echo "Done."
echo ""

echo "=========================================================="
# 2. Stage model weights, SAE checkpoints, and configuration files
echo "2. Staging deployment files in Git..."
echo "=========================================================="
git add data/trajectories_demo.pt models/mini_dt.pt artifacts/saes/ .gitignore Dockerfile docker-compose.yml Makefile scripts/deploy.sh src/dashboard/app.py .github/workflows/hf_sync.yml
echo "Staged."
echo ""

echo "=========================================================="
# 3. Commit changes locally (a no-op commit is tolerated, not an error)
echo "3. Committing staged changes..."
echo "=========================================================="
git commit -m "feat: redeploy fresh model weights and demo trajectories" || echo "No new changes to commit."
echo ""

echo "=========================================================="
# 4. Push to GitHub (to trigger auto-sync to Hugging Face Space)
echo "4. Pushing to GitHub (origin main)..."
echo "=========================================================="
git push origin main
echo ""
echo "Deployment successful! Check your Hugging Face Space or GitHub repository actions for the build status."
src/dashboard/app.py CHANGED
@@ -25,11 +25,37 @@ st.title("DT-Explorer: Mechanistic Interpretability for DT")
25
 
26
  # Sidebar for loading model and data
27
  st.sidebar.header("Data & Model")
28
- model_path = st.sidebar.text_input("Model Path", "models/mini_dt.pt")
29
- data_path = st.sidebar.text_input("Trajectory Path", "data/trajectories.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @st.cache_data
32
  def get_data(path):
 
 
 
 
33
  if not os.path.exists(path):
34
  st.sidebar.warning(f"Data not found at {path}. Please run training script.")
35
  return None
@@ -38,13 +64,17 @@ def get_data(path):
38
 
39
  @st.cache_resource
40
  def get_model(path, state_dim):
 
 
 
 
41
  if not os.path.exists(path):
42
  st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
43
  return HookedDT.from_config(state_dim=state_dim, action_dim=7)
44
 
45
  model = HookedDT.from_config(state_dim=state_dim, action_dim=7)
46
  try:
47
- # Load state dict (usually safe for weights_only=True, but let's be explicit)
48
  model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
49
  model.eval()
50
  except Exception as e:
 
25
 
26
  # Sidebar for loading model and data
27
  st.sidebar.header("Data & Model")
28
+
29
# Offer only the checkpoints that exist under models/ so the user chooses from
# a fixed list rather than typing an arbitrary path (prevents path traversal).
models_dir = Path("models")
available_models = (
    [str(candidate) for candidate in models_dir.glob("*.pt")]
    if models_dir.exists()
    else []
)
if len(available_models) == 0:
    # Nothing on disk: keep the default entry so the dropdown is never empty.
    available_models = ["models/mini_dt.pt"]
model_path = st.sidebar.selectbox("Select Model Path", sorted(available_models))

# Same approach for trajectory datasets under data/.
data_dir = Path("data")
available_data = (
    [str(candidate) for candidate in data_dir.glob("*.pt")]
    if data_dir.exists()
    else []
)
if len(available_data) == 0:
    available_data = ["data/trajectories.pt"]
data_path = st.sidebar.selectbox("Select Trajectory Path", sorted(available_data))
46
+
47
+ # Validation check to guarantee path safety (Defense-in-depth)
48
def is_safe_path(base_dir, path):
    """Return True when *path* resolves to *base_dir* or to something inside it.

    Both arguments are resolved to absolute paths first, so ``..`` segments
    are collapsed before the comparison. The check compares whole path
    components, so a sibling such as ``database/`` does not count as being
    inside ``data/``.
    """
    root_parts = Path(base_dir).resolve().parts
    candidate_parts = Path(path).resolve().parts
    leading = candidate_parts[:len(root_parts)]
    return leading == root_parts
52
 
53
  @st.cache_data
54
  def get_data(path):
55
+ if not is_safe_path("data", path):
56
+ st.sidebar.error("Access Denied: Invalid trajectory path.")
57
+ st.stop()
58
+
59
  if not os.path.exists(path):
60
  st.sidebar.warning(f"Data not found at {path}. Please run training script.")
61
  return None
 
64
 
65
  @st.cache_resource
66
  def get_model(path, state_dim):
67
+ if not is_safe_path("models", path):
68
+ st.sidebar.error("Access Denied: Invalid model path.")
69
+ st.stop()
70
+
71
  if not os.path.exists(path):
72
  st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
73
  return HookedDT.from_config(state_dim=state_dim, action_dim=7)
74
 
75
  model = HookedDT.from_config(state_dim=state_dim, action_dim=7)
76
  try:
77
+ # Load state dict (safe for weights_only=True)
78
  model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
79
  model.eval()
80
  except Exception as e: