Spaces:
Sleeping
Sleeping
Commit ·
e73506b
1
Parent(s): 33a0021
feat: package model weights, SAE checkpoints, and dynamic trajectories using Git LFS
Browse files- .gitattributes +1 -0
- .github/workflows/hf_sync.yml +54 -0
- .gitignore +9 -5
- Dockerfile +44 -0
- Makefile +37 -1
- artifacts/saes/blocks_0_hook_resid_post_sae.pt +3 -0
- artifacts/saes/blocks_1_hook_resid_post_sae.pt +3 -0
- data/trajectories_demo.pt +3 -0
- docker-compose.yml +20 -0
- models/mini_dt.pt +3 -0
- scripts/deploy.sh +59 -0
- src/dashboard/app.py +33 -3
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/hf_sync.yml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync to Hugging Face Spaces
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
workflow_dispatch: # Allows manual syncing from the GitHub Action tab
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
sync:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout Code
  uses: actions/checkout@v4
  with:
    # Fetch the complete history; a shallow clone cannot be pushed to another remote.
    fetch-depth: 0
    # Download the real LFS objects (the *.pt files tracked in .gitattributes).
    # Without this the runner only holds pointer files and the later
    # `git push` to the Space has no LFS content to upload.
    lfs: true
|
| 16 |
+
|
| 17 |
+
- name: Debug Secret Presence
  # Fail fast when the HF_TOKEN repository secret is not configured.
  # The secret is read only through the step environment: expanding
  # `${{ secrets.* }}` directly inside the script would paste the raw value
  # into the shell source, where quotes or `$` in the token change the script.
  run: |
    if [ -z "$HF_TOKEN" ]; then
      echo "❌ HF_TOKEN is empty or missing in GitHub Repository Secrets!"
      exit 1
    else
      echo "✅ HF_TOKEN is successfully configured in GitHub Repository Secrets (len: ${#HF_TOKEN})."
    fi
  env:
    HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 27 |
+
|
| 28 |
+
- name: Push to Hugging Face
|
| 29 |
+
env:
|
| 30 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 31 |
+
run: |
|
| 32 |
+
# 1. Initialize Git LFS inside the runner
|
| 33 |
+
git lfs install
|
| 34 |
+
|
| 35 |
+
# 2. Configure a named remote to allow Git LFS to authenticate correctly
|
| 36 |
+
git remote add hf https://sadhumitha-s:$HF_TOKEN@huggingface.co/spaces/sadhumitha-s/DT-Explorer
|
| 37 |
+
|
| 38 |
+
# 3. Force push using HEAD:main to the hf remote
|
| 39 |
+
git push --force hf HEAD:main 2> push_err.txt || {
|
| 40 |
+
echo "=== HUGGING FACE PUSH ERROR LOG ==="
|
| 41 |
+
if [ -n "$HF_TOKEN" ]; then
|
| 42 |
+
sed "s/$HF_TOKEN/*****_TOKEN/g" push_err.txt
|
| 43 |
+
else
|
| 44 |
+
cat push_err.txt
|
| 45 |
+
fi
|
| 46 |
+
exit 1
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
.gitignore
CHANGED
|
@@ -26,12 +26,14 @@ venv/
|
|
| 26 |
ENV/
|
| 27 |
|
| 28 |
# Data and Models
|
| 29 |
-
data/
|
| 30 |
-
|
| 31 |
-
models/*
|
|
|
|
|
|
|
| 32 |
*.zip
|
| 33 |
*.h5
|
| 34 |
-
|
| 35 |
|
| 36 |
# Experiment Tracking
|
| 37 |
wandb/
|
|
@@ -55,6 +57,8 @@ static/
|
|
| 55 |
.venv
|
| 56 |
|
| 57 |
/PRD.md
|
| 58 |
-
artifacts/
|
|
|
|
|
|
|
| 59 |
scratch/
|
| 60 |
*.log
|
|
|
|
| 26 |
ENV/
|
| 27 |
|
| 28 |
# Data and Models
|
| 29 |
+
data/*
|
| 30 |
+
!data/trajectories_demo.pt
|
| 31 |
+
models/*
|
| 32 |
+
!models/mini_dt.pt
|
| 33 |
+
*.pth
|
| 34 |
*.zip
|
| 35 |
*.h5
|
| 36 |
+
|
| 37 |
|
| 38 |
# Experiment Tracking
|
| 39 |
wandb/
|
|
|
|
| 57 |
.venv
|
| 58 |
|
| 59 |
/PRD.md
|
| 60 |
+
artifacts/*
|
| 61 |
+
!artifacts/saes/
|
| 62 |
+
!artifacts/saes/*
|
| 63 |
scratch/
|
| 64 |
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies required for compilation and Gymnasium rendering.
# `libgl1` replaces `libgl1-mesa-glx`: the latter is only a transitional package
# and is absent from newer Debian releases that the python:3.10-slim tag tracks,
# which makes `apt-get install` fail with "no installation candidate".
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
# Copy and install python dependencies first to cache this layer
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 17 |
+
pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Create necessary directories
|
| 20 |
+
RUN mkdir -p models data artifacts/saes
|
| 21 |
+
|
| 22 |
+
# Copy trained model weights and SAE features
|
| 23 |
+
COPY models/mini_dt.pt ./models/mini_dt.pt
|
| 24 |
+
COPY artifacts/saes/ ./artifacts/saes/
|
| 25 |
+
|
| 26 |
+
# Bake in the lightweight demo trajectories as the default dataset
|
| 27 |
+
COPY data/trajectories_demo.pt ./data/trajectories.pt
|
| 28 |
+
|
| 29 |
+
# Copy codebase
|
| 30 |
+
COPY src/ ./src/
|
| 31 |
+
|
| 32 |
+
# Expose default Streamlit port
|
| 33 |
+
EXPOSE 8501
|
| 34 |
+
|
| 35 |
+
# Streamlit configurations for production/cloud environments
|
| 36 |
+
ENV STREAMLIT_SERVER_PORT=8501
|
| 37 |
+
ENV STREAMLIT_SERVER_ADDRESS=0.0.0.0
|
| 38 |
+
ENV STREAMLIT_SERVER_HEADLESS=true
|
| 39 |
+
ENV STREAMLIT_SERVER_ENABLE_CORS=false
|
| 40 |
+
ENV STREAMLIT_SERVER_ENABLE_XSRF=true
|
| 41 |
+
ENV STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
|
| 42 |
+
|
| 43 |
+
# Start the dashboard application
|
| 44 |
+
CMD ["streamlit", "run", "src/dashboard/app.py"]
|
Makefile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
.PHONY: setup train dashboard test clean
|
| 2 |
|
| 3 |
# Setup environment
|
| 4 |
setup:
|
|
@@ -17,8 +17,44 @@ dashboard:
|
|
| 17 |
test:
|
| 18 |
PYTHONPATH=. pytest tests/
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# Remove artifacts and cached files
|
| 21 |
clean:
|
| 22 |
rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
|
| 23 |
find . -type d -name "__pycache__" -exec rm -rf {} +
|
| 24 |
find . -type d -name ".pytest_cache" -exec rm -rf {} +
|
|
|
|
|
|
| 1 |
+
.PHONY: setup train dashboard test clean deploy
|
| 2 |
|
| 3 |
# Setup environment
|
| 4 |
setup:
|
|
|
|
| 17 |
test:
|
| 18 |
PYTHONPATH=. pytest tests/
|
| 19 |
|
| 20 |
+
# Package and deploy to Hugging Face Spaces.
#
# Step 1 writes data/trajectories_demo.pt from data/trajectories.pt, shrinking
# the slice until the saved file is below 9.5 MB. The Python program is fed to
# `python3 -` one line per printf argument: joining the lines with backslash
# continuations inside a single-quoted `-c` string hands Python one logical
# line, and compound statements (`if`, `while`) cannot follow `;` on that line.
# Step 2 stages and commits the deployment assets; step 3 pushes to the `hf` remote.
deploy:
	@echo "1. Slicing trajectories to data/trajectories_demo.pt (with zero-hardcoded guardrail)..."
	@echo "=========================================================="
	@printf '%s\n' \
		'import torch, os' \
		'data = torch.load("data/trajectories.pt", map_location="cpu", weights_only=False)' \
		'count = len(data)' \
		'torch.save(data[:count], "data/trajectories_demo.pt")' \
		'size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)' \
		'if size_mb >= 9.5:' \
		'    avg_size = size_mb / count' \
		'    count = int(9.0 / avg_size)' \
		'    while count > 0:' \
		'        demo_data = data[:count]' \
		'        torch.save(demo_data, "data/trajectories_demo.pt")' \
		'        size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)' \
		'        if size_mb < 9.5:' \
		'            break' \
		'        count -= 1' \
		'print(f"Successfully packaged {count}/{len(data)} trajectories (Size: {size_mb:.2f} MB). Safely under 10MB limit.")' \
		| python3 -
	@echo "Done."
	@echo ""
	@echo "=========================================================="
	@echo "2. Staging and committing deployment assets..."
	@echo "=========================================================="
	@git add data/trajectories_demo.pt models/mini_dt.pt artifacts/saes/ .gitignore Dockerfile docker-compose.yml Makefile scripts/deploy.sh src/dashboard/app.py .github/workflows/hf_sync.yml
	@git commit -m "feat: redeploy fresh model weights and demo trajectories" || echo "No new changes to commit."
	@echo ""
	@echo "=========================================================="
	@echo "3. Pushing changes to Hugging Face Spaces ('hf' remote)..."
	@echo "=========================================================="
	@git push hf main || echo "Failed to push to 'hf' remote automatically. Please verify your Space git remote is named 'hf', or manually push to your target remote (e.g. 'git push origin main')."
|
| 54 |
+
|
| 55 |
# Remove artifacts and cached files
|
| 56 |
clean:
|
| 57 |
rm -rf data/*.pt models/*.pt artifacts/saes/*.pt
|
| 58 |
find . -type d -name "__pycache__" -exec rm -rf {} +
|
| 59 |
find . -type d -name ".pytest_cache" -exec rm -rf {} +
|
| 60 |
+
|
artifacts/saes/blocks_0_hook_resid_post_sae.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea054a846170dc71b71c8a65dd9091dfb670bf177641a63633b39281c270e658
|
| 3 |
+
size 1056559
|
artifacts/saes/blocks_1_hook_resid_post_sae.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5736510ba07d2dc7c8481a7842011da02987ae4e875c68f28d06b2b0bc0f3c6c
|
| 3 |
+
size 1056559
|
data/trajectories_demo.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b249672da20d83956eb4344a55eff64f1444a761fcd745a0873824e6f13ece4
|
| 3 |
+
size 8599401
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
|
| 3 |
+
services:
|
| 4 |
+
dt-explorer:
|
| 5 |
+
build:
|
| 6 |
+
context: .
|
| 7 |
+
dockerfile: Dockerfile
|
| 8 |
+
container_name: dt_explorer
|
| 9 |
+
ports:
|
| 10 |
+
- "8501:8501"
|
| 11 |
+
volumes:
|
| 12 |
+
# Mount models and data directories to reflect updates in real-time
|
| 13 |
+
- ./models:/app/models
|
| 14 |
+
- ./data:/app/data
|
| 15 |
+
- ./artifacts:/app/artifacts
|
| 16 |
+
environment:
|
| 17 |
+
- PYTHONUNBUFFERED=1
|
| 18 |
+
# Optional: set wandb or neuronpedia tokens here if needed
|
| 19 |
+
# - WANDB_API_KEY=${WANDB_API_KEY}
|
| 20 |
+
restart: unless-stopped
|
models/mini_dt.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59f7d0f9708e6c2a22a56f369df757df57d2a0d750535ed4435f52dea89f5fcd
|
| 3 |
+
size 4605691
|
scripts/deploy.sh
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
# DT-Explorer Automated Deployment Script
|
| 5 |
+
# This script handles the raw bash workflow for solo researchers to update their hosted web app.
|
| 6 |
+
|
| 7 |
+
echo "=========================================================="
|
| 8 |
+
# 1. Slice heavy local trajectories into a lightweight demo set (zero-hardcoded dynamic scaling)
|
| 9 |
+
echo "1. Slicing local trajectories down to data/trajectories_demo.pt..."
|
| 10 |
+
echo "=========================================================="
|
| 11 |
+
python3 -c '
|
| 12 |
+
import torch, os
|
| 13 |
+
data = torch.load("data/trajectories.pt", map_location="cpu", weights_only=False)
|
| 14 |
+
count = len(data)
|
| 15 |
+
# Try full dataset first
|
| 16 |
+
torch.save(data[:count], "data/trajectories_demo.pt")
|
| 17 |
+
size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)
|
| 18 |
+
|
| 19 |
+
if size_mb >= 9.5:
|
| 20 |
+
# Calculate average size per trajectory and estimate safe capacity
|
| 21 |
+
avg_size = size_mb / count
|
| 22 |
+
count = int(9.0 / avg_size) # Aim for ~9.0 MB to be safe
|
| 23 |
+
|
| 24 |
+
# Verify and make minor adjustments if needed
|
| 25 |
+
while count > 0:
|
| 26 |
+
demo_data = data[:count]
|
| 27 |
+
torch.save(demo_data, "data/trajectories_demo.pt")
|
| 28 |
+
size_mb = os.path.getsize("data/trajectories_demo.pt") / (1024*1024)
|
| 29 |
+
if size_mb < 9.5:
|
| 30 |
+
break
|
| 31 |
+
count -= 1
|
| 32 |
+
|
| 33 |
+
print(f"Successfully packaged {count}/{len(data)} trajectories (Size: {size_mb:.2f} MB). Safely under 10MB limit.")
|
| 34 |
+
'
|
| 35 |
+
echo "Done."
|
| 36 |
+
echo ""
|
| 37 |
+
|
| 38 |
+
echo "=========================================================="
|
| 39 |
+
# 2. Stage model weights, SAE checkpoints, and configuration files
|
| 40 |
+
echo "2. Staging deployment files in Git..."
|
| 41 |
+
echo "=========================================================="
|
| 42 |
+
git add data/trajectories_demo.pt models/mini_dt.pt artifacts/saes/ .gitignore Dockerfile docker-compose.yml Makefile scripts/deploy.sh src/dashboard/app.py .github/workflows/hf_sync.yml
|
| 43 |
+
echo "Staged."
|
| 44 |
+
echo ""
|
| 45 |
+
|
| 46 |
+
echo "=========================================================="
|
| 47 |
+
# 3. Commit changes locally
|
| 48 |
+
echo "3. Committing staged changes..."
|
| 49 |
+
echo "=========================================================="
|
| 50 |
+
git commit -m "feat: redeploy fresh model weights and demo trajectories" || echo "No new changes to commit."
|
| 51 |
+
echo ""
|
| 52 |
+
|
| 53 |
+
echo "=========================================================="
|
| 54 |
+
# 4. Push to GitHub (to trigger auto-sync to Hugging Face Space)
|
| 55 |
+
echo "4. Pushing to GitHub (origin main)..."
|
| 56 |
+
echo "=========================================================="
|
| 57 |
+
git push origin main
|
| 58 |
+
echo ""
|
| 59 |
+
echo "Deployment successful! Check your Hugging Face Space or GitHub repository actions for the build status."
|
src/dashboard/app.py
CHANGED
|
@@ -25,11 +25,37 @@ st.title("DT-Explorer: Mechanistic Interpretability for DT")
|
|
| 25 |
|
| 26 |
# Sidebar for loading model and data
|
| 27 |
st.sidebar.header("Data & Model")
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
@st.cache_data
|
| 32 |
def get_data(path):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
if not os.path.exists(path):
|
| 34 |
st.sidebar.warning(f"Data not found at {path}. Please run training script.")
|
| 35 |
return None
|
|
@@ -38,13 +64,17 @@ def get_data(path):
|
|
| 38 |
|
| 39 |
@st.cache_resource
|
| 40 |
def get_model(path, state_dim):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
if not os.path.exists(path):
|
| 42 |
st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
|
| 43 |
return HookedDT.from_config(state_dim=state_dim, action_dim=7)
|
| 44 |
|
| 45 |
model = HookedDT.from_config(state_dim=state_dim, action_dim=7)
|
| 46 |
try:
|
| 47 |
-
# Load state dict (
|
| 48 |
model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
|
| 49 |
model.eval()
|
| 50 |
except Exception as e:
|
|
|
|
| 25 |
|
| 26 |
# Sidebar for loading model and data
|
| 27 |
st.sidebar.header("Data & Model")
|
| 28 |
+
|
| 29 |
+
# List available models in a secure dropdown to prevent Path Traversal
|
| 30 |
+
models_dir = Path("models")
|
| 31 |
+
available_models = []
|
| 32 |
+
if models_dir.exists():
|
| 33 |
+
available_models = [str(p) for p in models_dir.glob("*.pt")]
|
| 34 |
+
if not available_models:
|
| 35 |
+
available_models = ["models/mini_dt.pt"]
|
| 36 |
+
model_path = st.sidebar.selectbox("Select Model Path", sorted(available_models))
|
| 37 |
+
|
| 38 |
+
# List available datasets in a secure dropdown to prevent Path Traversal
|
| 39 |
+
data_dir = Path("data")
|
| 40 |
+
available_data = []
|
| 41 |
+
if data_dir.exists():
|
| 42 |
+
available_data = [str(p) for p in data_dir.glob("*.pt")]
|
| 43 |
+
if not available_data:
|
| 44 |
+
available_data = ["data/trajectories.pt"]
|
| 45 |
+
data_path = st.sidebar.selectbox("Select Trajectory Path", sorted(available_data))
|
| 46 |
+
|
| 47 |
+
# Validation check to guarantee path safety (Defense-in-depth)
|
| 48 |
+
def is_safe_path(base_dir, path):
    """Return True when *path* resolves to *base_dir* itself or something beneath it.

    Both arguments are resolved to absolute paths first, so ``..`` segments
    and symlinks are collapsed before the comparison. The check is made on
    whole path components, so a sibling such as ``data_evil`` does not pass
    as being inside ``data``.
    """
    root_parts = Path(base_dir).resolve().parts
    candidate_parts = Path(path).resolve().parts
    # The candidate is inside the root iff its leading components equal the root's.
    leading = candidate_parts[:len(root_parts)]
    return leading == root_parts
|
| 52 |
|
| 53 |
@st.cache_data
|
| 54 |
def get_data(path):
|
| 55 |
+
if not is_safe_path("data", path):
|
| 56 |
+
st.sidebar.error("Access Denied: Invalid trajectory path.")
|
| 57 |
+
st.stop()
|
| 58 |
+
|
| 59 |
if not os.path.exists(path):
|
| 60 |
st.sidebar.warning(f"Data not found at {path}. Please run training script.")
|
| 61 |
return None
|
|
|
|
| 64 |
|
| 65 |
@st.cache_resource
|
| 66 |
def get_model(path, state_dim):
|
| 67 |
+
if not is_safe_path("models", path):
|
| 68 |
+
st.sidebar.error("Access Denied: Invalid model path.")
|
| 69 |
+
st.stop()
|
| 70 |
+
|
| 71 |
if not os.path.exists(path):
|
| 72 |
st.sidebar.warning(f"Model not found at {path}. Using random init for demo.")
|
| 73 |
return HookedDT.from_config(state_dim=state_dim, action_dim=7)
|
| 74 |
|
| 75 |
model = HookedDT.from_config(state_dim=state_dim, action_dim=7)
|
| 76 |
try:
|
| 77 |
+
# Load state dict (safe for weights_only=True)
|
| 78 |
model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
|
| 79 |
model.eval()
|
| 80 |
except Exception as e:
|