VoyagerXHF commited on
Commit
d73a341
Β·
verified Β·
1 Parent(s): 66ecb01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -11,7 +11,13 @@ from collections import OrderedDict
11
  import gradio as gr
12
  import spaces
13
  import torch
14
- from huggingface_hub import hf_hub_download, snapshot_download
 
 
 
 
 
 
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
16
 
17
  # ─── CLI arguments ────────────────────────────────────────────────────────────
@@ -162,9 +168,9 @@ def get_model():
162
  if _model is None:
163
  print("Loading model…")
164
  _model = AutoModelForCausalLM.from_pretrained(
165
- MODEL_PATH, device_map='auto', torch_dtype='auto'
166
  )
167
- _tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
168
  _model.eval()
169
  print("Model ready.")
170
  return _model, _tokenizer
@@ -183,7 +189,7 @@ def get_sae(layer: int) -> dict:
183
  else:
184
  # Assume HF Hub repo ID – download once, then read from local cache.
185
  if _sae_local_dir is None:
186
- _sae_local_dir = snapshot_download(SAE_PATH, cache_dir='./sae_cache', local_files_only=False)
187
  path = os.path.join(_sae_local_dir, f'layer{layer}.sae.pt')
188
  try:
189
  sae = torch.load(path, map_location=SAE_DEVICE, weights_only=True)
 
11
  import gradio as gr
12
  import spaces
13
  import torch
14
+ from huggingface_hub import hf_hub_download, snapshot_download, login
15
+
16
+ # Login to HuggingFace Hub if HF_TOKEN is set (required for private repos)
17
+ _hf_token = os.environ.get('HF_TOKEN')
18
+ if _hf_token:
19
+ login(token=_hf_token)
20
+ print("Logged in to HuggingFace Hub with HF_TOKEN.")
21
  from transformers import AutoModelForCausalLM, AutoTokenizer
22
 
23
  # ─── CLI arguments ────────────────────────────────────────────────────────────
 
168
  if _model is None:
169
  print("Loading model…")
170
  _model = AutoModelForCausalLM.from_pretrained(
171
+ MODEL_PATH, device_map='auto', torch_dtype='auto', token=_hf_token
172
  )
173
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=_hf_token)
174
  _model.eval()
175
  print("Model ready.")
176
  return _model, _tokenizer
 
189
  else:
190
  # Assume HF Hub repo ID – download once, then read from local cache.
191
  if _sae_local_dir is None:
192
+ _sae_local_dir = snapshot_download(SAE_PATH, cache_dir='./sae_cache', local_files_only=False, token=_hf_token)
193
  path = os.path.join(_sae_local_dir, f'layer{layer}.sae.pt')
194
  try:
195
  sae = torch.load(path, map_location=SAE_DEVICE, weights_only=True)