Vasanthakumar R committed on
Commit
f244f86
·
1 Parent(s): 9959ec6

feat: add ZeroGPU support via @spaces.GPU decorator

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -9,6 +9,7 @@ Deploy:
9
  """
10
 
11
  import os
 
12
  import gradio as gr
13
  import torch
14
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,16 +24,13 @@ SYSTEM_PROMPT = (
23
  "You ALWAYS provide a response — never return empty output."
24
  )
25
 
26
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
27
- dtype = torch.float16 if device != "cpu" else torch.float32
28
-
29
- print(f"Loading {MODEL_ID} on {device}...")
30
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
31
  if tokenizer.pad_token is None:
32
  tokenizer.pad_token = tokenizer.eos_token
33
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
- MODEL_ID, torch_dtype=dtype, device_map="auto", trust_remote_code=True,
36
  )
37
 
38
  if ADAPTER_ID:
@@ -43,7 +41,8 @@ model.eval()
43
  print("Model ready!")
44
 
45
 
46
- # ── Inference ───────────────────────────────────────────────────
 
47
  def scan_code(language: str, code: str, max_tokens: int = 1024) -> str:
48
  if not code.strip():
49
  return "Paste some code to scan."
 
9
  """
10
 
11
  import os
12
+ import spaces
13
  import gradio as gr
14
  import torch
15
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
24
  "You ALWAYS provide a response — never return empty output."
25
  )
26
 
27
+ print(f"Loading {MODEL_ID}...")
 
 
 
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
29
  if tokenizer.pad_token is None:
30
  tokenizer.pad_token = tokenizer.eos_token
31
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
+ MODEL_ID, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True,
34
  )
35
 
36
  if ADAPTER_ID:
 
41
  print("Model ready!")
42
 
43
 
44
+ # ── Inference (GPU allocated only during this call) ─────────────
45
+ @spaces.GPU(duration=120)
46
  def scan_code(language: str, code: str, max_tokens: int = 1024) -> str:
47
  if not code.strip():
48
  return "Paste some code to scan."