Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,13 @@ from collections import OrderedDict
|
|
| 11 |
import gradio as gr
|
| 12 |
import spaces
|
| 13 |
import torch
|
| 14 |
-
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 16 |
|
| 17 |
# βββ CLI arguments ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -162,9 +168,9 @@ def get_model():
|
|
| 162 |
if _model is None:
|
| 163 |
print("Loading modelβ¦")
|
| 164 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 165 |
-
MODEL_PATH, device_map='auto', torch_dtype='auto'
|
| 166 |
)
|
| 167 |
-
_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
| 168 |
_model.eval()
|
| 169 |
print("Model ready.")
|
| 170 |
return _model, _tokenizer
|
|
@@ -183,7 +189,7 @@ def get_sae(layer: int) -> dict:
|
|
| 183 |
else:
|
| 184 |
# Assume HF Hub repo ID β download once, then read from local cache.
|
| 185 |
if _sae_local_dir is None:
|
| 186 |
-
_sae_local_dir = snapshot_download(SAE_PATH, cache_dir='./sae_cache', local_files_only=False)
|
| 187 |
path = os.path.join(_sae_local_dir, f'layer{layer}.sae.pt')
|
| 188 |
try:
|
| 189 |
sae = torch.load(path, map_location=SAE_DEVICE, weights_only=True)
|
|
|
|
| 11 |
import gradio as gr
|
| 12 |
import spaces
|
| 13 |
import torch
|
| 14 |
+
from huggingface_hub import hf_hub_download, snapshot_download, login
|
| 15 |
+
|
| 16 |
+
# Login to HuggingFace Hub if HF_TOKEN is set (required for private repos)
|
| 17 |
+
_hf_token = os.environ.get('HF_TOKEN')
|
| 18 |
+
if _hf_token:
|
| 19 |
+
login(token=_hf_token)
|
| 20 |
+
print("Logged in to HuggingFace Hub with HF_TOKEN.")
|
| 21 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
|
| 23 |
# βββ CLI arguments ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 168 |
if _model is None:
|
| 169 |
print("Loading modelβ¦")
|
| 170 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 171 |
+
MODEL_PATH, device_map='auto', torch_dtype='auto', token=_hf_token
|
| 172 |
)
|
| 173 |
+
_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=_hf_token)
|
| 174 |
_model.eval()
|
| 175 |
print("Model ready.")
|
| 176 |
return _model, _tokenizer
|
|
|
|
| 189 |
else:
|
| 190 |
# Assume HF Hub repo ID β download once, then read from local cache.
|
| 191 |
if _sae_local_dir is None:
|
| 192 |
+
_sae_local_dir = snapshot_download(SAE_PATH, cache_dir='./sae_cache', local_files_only=False, token=_hf_token)
|
| 193 |
path = os.path.join(_sae_local_dir, f'layer{layer}.sae.pt')
|
| 194 |
try:
|
| 195 |
sae = torch.load(path, map_location=SAE_DEVICE, weights_only=True)
|