usermma commited on
Commit
a3a93be
·
verified ·
1 Parent(s): 36565b5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI, Request
3
+ from fastapi.responses import HTMLResponse
4
+ from fastapi.templating import Jinja2Templates
5
+ from safetensors.torch import load_file
6
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
+
10
+ app = FastAPI()
11
+
12
+ templates = Jinja2Templates(directory=".")
13
+
14
+ print("Loading nanoWhale-100m model...")
15
+ config = AutoConfig.from_pretrained("HuggingFaceTB/nanowhale-100m", trust_remote_code=True)
16
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).float()
17
+
18
+ weights_path = hf_hub_download("HuggingFaceTB/nanowhale-100m", "model.safetensors")
19
+ state_dict = load_file(weights_path)
20
+ model.load_state_dict(state_dict, strict=True)
21
+ model = model.eval()
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/nanowhale-100m")
24
+
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ model = model.to(device)
27
+ print(f"Model loaded on {device}")
28
+
29
+ @app.get("/", response_class=HTMLResponse)
30
+ async def get_index(request: Request):
31
+ return templates.TemplateResponse("index.html", {"request": request})
32
+
33
+ @app.post("/generate")
34
+ async def generate_text(request: Request):
35
+ data = await request.json()
36
+ user_prompt = data.get("prompt", "")
37
+
38
+ if not user_prompt:
39
+ return {"error": "No prompt provided"}
40
+
41
+ try:
42
+ messages = [{"role": "user", "content": user_prompt}]
43
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
44
+
45
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
46
+
47
+ with torch.no_grad():
48
+ output = model.generate(
49
+ input_ids,
50
+ max_new_tokens=200,
51
+ temperature=0.7,
52
+ top_p=0.9,
53
+ do_sample=True,
54
+ pad_token_id=tokenizer.eos_token_id
55
+ )
56
+
57
+ generated = output[0][input_ids.shape[1]:]
58
+ response_text = tokenizer.decode(generated, skip_special_tokens=True)
59
+
60
+ return {"response": response_text}
61
+
62
+ except Exception as e:
63
+ return {"error": str(e)}