nikravan committed on
Commit
8615e88
·
verified ·
1 Parent(s): eb06650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -3,16 +3,29 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from spaces import GPU
5
 
6
- model_name = "microsoft/DialoGPT-small"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @GPU
14
  def generate_response(message, history):
15
- input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt").to(device)
 
16
  chat_history_ids = input_ids
17
  response_ids = model.generate(
18
  chat_history_ids,
@@ -21,7 +34,10 @@ def generate_response(message, history):
21
  do_sample=True,
22
  temperature=0.7
23
  )
24
- response = tokenizer.decode(response_ids[:, chat_history_ids.shape[-1]:][0], skip_special_tokens=True)
 
 
 
25
  return response.strip()
26
 
27
  chatbot = gr.ChatInterface(
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from spaces import GPU
5
 
 
 
 
6
 
7
+
8
+ model_id = "sapientinc/HRM-Text-1B"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_id,
12
+ dtype=torch.bfloat16,
13
+ trust_remote_code=True,
14
+ ).cuda().eval()
15
+
16
+ # synth,cot composite — reasoning / CoT style (see Disclaimer for other modes)
17
+ condition = "<|quad_end|><|object_ref_end|>"
18
+ prompt = f"<|im_start|>{condition}Explain why the sky is blue.<|im_end|>"
19
+
20
+
21
+ # Mark the prompt as a single bidirectional prefix block — see "PrefixLM mask" below.
22
+
23
+
24
 
25
  @GPU
26
  def generate_response(message, history):
27
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
28
+ input_ids["token_type_ids"] = torch.ones_like(input_ids["input_ids"])
29
  chat_history_ids = input_ids
30
  response_ids = model.generate(
31
  chat_history_ids,
 
34
  do_sample=True,
35
  temperature=0.7
36
  )
37
+
38
+ with torch.no_grad():
39
+ out = model.generate(**chat_history_ids, max_new_tokens=256, do_sample=False)
40
+ response = tokenizer.decode(out[0], skip_special_tokens=False)
41
  return response.strip()
42
 
43
  chatbot = gr.ChatInterface(