| import gradio as gr |
| import os, gc, copy, torch |
| from datetime import datetime |
| from pynvml import * |
| from duckduckgo_search import DDGS |
| import re |
| import asyncio |
| import subprocess |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| |
# GPU discovery via NVML. Sets the module-level flags HAS_GPU / GPU_COUNT,
# which evaluate() reads later for post-generation VRAM reporting.
HAS_GPU = False
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        print(f"检测到 {GPU_COUNT} 个GPU设备")
        # Log each device's name and total memory once at startup.
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            name = nvmlDeviceGetName(handle)
            print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB")
except NVMLError as error:
    # NOTE(review): when NVML init fails, GPU_COUNT stays undefined; that is
    # safe as written because evaluate() only reads it when HAS_GPU is True.
    print(error)
|
|
| |
# Model configuration.
model_id = "Orion-zhen/c4ai-command-r-08-2024-h-novel-exl2"
# NOTE(review): ctx_limit is defined but never referenced below — presumably
# an intended prompt-length cap; confirm before wiring it into evaluate().
ctx_limit = 20000
# Fixed typos in the UI title: "ommand" -> "command", "exl2with" -> "exl2 with".
title = "c4ai command r 08 2024 h novel exl2 with RAG"

print(f"正在并行加载模型 {model_id} 到多块 GPU...")
|
|
| |
# Load tokenizer and model once at import time; both are used as module-level
# globals by evaluate().
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # shard layers across all visible GPUs
    torch_dtype=torch.float16,  # half precision to fit large models in VRAM
    low_cpu_mem_usage=True,     # stream weights instead of full CPU copy
    trust_remote_code=True      # NOTE(review): executes repo code — only safe for trusted model repos
)


print(f"模型加载完成。设备映射: {model.hf_device_map}")
|
|
| |
async def understanding_question(question: str):
    """Normalize a user question for downstream matching.

    Lowercases the text and strips one leading politeness/interrogative
    phrase (e.g. "what is", "please") if present.
    """
    lowered = question.lower()
    leading_phrases = (
        r'^(can you|could you|please|tell me about|what is|who is'
        r'|how to|why is|when did)\s+'
    )
    return re.sub(leading_phrases, '', lowered)
|
|
def web_search(query, max_results=3):
    """Run a DuckDuckGo text search and return a formatted summary block.

    Best-effort: returns "" when there are no hits or on any failure,
    so callers can unconditionally append the result to a prompt.
    """
    try:
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
        }
        with DDGS(headers=headers) as ddgs:
            hits = list(ddgs.text(query, max_results=max_results))
            if not hits:
                return ""
            entries = [
                f"[{idx}] {hit['title']}\nSummary: {hit['body']}\n\n"
                for idx, hit in enumerate(hits, 1)
            ]
            return "\n\nSearch Results:\n" + "".join(entries)
    except Exception as e:
        print(f"Search error: {e}")
        return ""
|
|
def extract_search_query(text):
    """Derive a short web-search query from raw user input.

    Strips chat role prefixes ("user:" / "a:"), removes one leading
    interrogative phrase, lowercases, and truncates to 100 characters.
    """
    text = text.lower()
    # \b anchors so the prefixes are only removed at word boundaries — the
    # old pattern (r'user:\s*|a:\s*') matched "a:" inside words and turned
    # e.g. "data: x" into "datx".
    text = re.sub(r'\buser:\s*|\ba:\s*', '', text)
    text = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', text)
    return text[:100]
|
|
| |
def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    """Generate a model response for *ctx*, optionally augmented with
    DuckDuckGo search results (simple RAG).

    Parameters mirror the Gradio sliders. NOTE(review): countPenalty is
    accepted but never used below — presumably intended as a frequency
    penalty; confirm before removing or wiring it in.
    """
    # Derive a search query from the prompt; only search for queries long
    # enough to be meaningful.
    search_query = extract_search_query(ctx)
    search_results = ""
    if len(search_query) > 5:
        search_results = web_search(search_query)

    # Prepend search results as reference context when available.
    user_content = ctx
    if search_results:
        user_content = f"参考信息:\n{search_results}\n\n用户问题:{ctx}"

    messages = [
        {"role": "user", "content": user_content},
    ]


    # Build model inputs via the chat template and move them to the model's
    # (first) device; device_map="auto" handles cross-GPU layer placement.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)


    gen_kwargs = {
        "max_new_tokens": int(token_count),
        # Clamp temperature to >= 0.1 to avoid degenerate sampling.
        "temperature": max(0.1, float(temperature)),
        "top_p": float(top_p),
        # presencePenalty is mapped onto repetition_penalty (1.0 = neutral).
        "repetition_penalty": float(1.0 + presencePenalty),
        "do_sample": True,
        # Silences the "no pad token" warning for models without one.
        "pad_token_id": tokenizer.eos_token_id
    }


    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens (slice off the prompt).
    prompt_len = inputs.input_ids.shape[1]
    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)


    # Free cached VRAM and log per-GPU usage after each generation.
    if HAS_GPU:
        torch.cuda.empty_cache()
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            gpu_info = nvmlDeviceGetMemoryInfo(handle)
            print(f'GPU {i} VRAM: 已用 {gpu_info.used/(1024**3):.2f}GB / 总计 {gpu_info.total/(1024**3):.2f}GB')

    return result
|
|
| |
# frp reverse-proxy settings for exposing the local Gradio port.
port=7860
use_frpc=True
# NOTE(review): filename says "7680" while the port is 7860 — looks like a
# digit transposition; confirm the actual config file name on disk.
frpconfigfile="7680.ini"
|
|
def install_Frpc(port, frpconfigfile, use_frpc):
    """Mark the bundled ./frpc binary executable and launch it in the
    background with the given config file.

    Does nothing when use_frpc is falsy. The port is only used in the
    startup log message; the tunnel itself is defined by the config file.
    """
    if not use_frpc:
        return
    subprocess.run(['chmod', '+x', './frpc'], check=True)
    print(f'正在启动frp ,端口{port}')
    subprocess.Popen(['./frpc', '-c', frpconfigfile])
|
|
# Pass the configured `port` variable instead of a duplicated string literal
# ('7860'); the value is only interpolated into the startup log message, so
# behavior is unchanged while the single source of truth is restored.
install_Frpc(port, frpconfigfile, use_frpc)
|
|
| |
# Gradio UI: a single-page app with prompt + sampling controls on the left
# and the generated output on the right.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\"><h1>{title}</h1></div>")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=5, label="提示词", value="")
            # NOTE(review): slider max (10000) is below ctx_limit (20000) —
            # confirm which cap is intended.
            token_count = gr.Slider(100, 10000, label="最大Token数", step=100, value=500)
            temperature = gr.Slider(0.2, 2.0, label="温度", step=0.1, value=1.0)
            top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.7)
            presence_penalty = gr.Slider(0.0, 1.0, label="存在惩罚", step=0.1, value=0.1)
            count_penalty = gr.Slider(0.0, 1.0, label="计数惩罚", step=0.1, value=0.1)
        with gr.Column():
            submit = gr.Button("提交", variant="primary")
            output = gr.Textbox(label="输出", lines=15)
            clear = gr.Button("清除", variant="secondary")


    # Wire the submit button to evaluate(); argument order must match
    # evaluate()'s positional parameters.
    submit.click(
        evaluate,
        [prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
        [output]
    )
    # Clear simply resets the output textbox to None.
    clear.click(lambda: None, [], [output])


# Local-only launch; external access is provided via the frpc tunnel above.
demo.launch(share=False)