import gradio as gr import os, gc, copy, torch from datetime import datetime from huggingface_hub import hf_hub_download from pynvml import * from duckduckgo_search import DDGS import re import asyncio # Flag to check if GPU is present HAS_GPU = False # Model title and context size limit ctx_limit = 20000 title = "RWKV-5-World-3B-v2-20231025-ctx4096 with RAG" model_file = "rwkv-5-h-world-7B" # Get the GPU count try: nvmlInit() GPU_COUNT = nvmlDeviceGetCount() if GPU_COUNT > 0: HAS_GPU = True gpu_h = nvmlDeviceGetHandleByIndex(0) print(f"检测到 {GPU_COUNT} 个GPU设备") for i in range(GPU_COUNT): handle = nvmlDeviceGetHandleByIndex(i) info = nvmlDeviceGetMemoryInfo(handle) name = nvmlDeviceGetName(handle) print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB") except NVMLError as error: print(error) os.environ["RWKV_JIT_ON"] = '1' # Model strat to use MODEL_STRAT="cpu bf16" os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster) # Switch to GPU mode if HAS_GPU == True: os.environ["RWKV_CUDA_ON"] = '1' if GPU_COUNT >= 2: # 使用两块GPU进行模型加载 MODEL_STRAT = "cuda:0 fp16 *10 -> cuda:1 fp16" print(f"使用多GPU策略: {MODEL_STRAT}") else: MODEL_STRAT = "cuda fp16" print(f"使用单GPU策略: {MODEL_STRAT}") # Load the model accordingly from rwkv.model import RWKV model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth") print(f"加载模型: {model_path}") model = RWKV(model=model_path, strategy=MODEL_STRAT) from rwkv.utils import PIPELINE, PIPELINE_ARGS pipeline = PIPELINE(model, "rwkv_vocab_v20230424") # 理解问题并提取关键词的函数 async def understanding_question(question: str): # 简单处理:移除常见的问题词,保留关键内容 question = question.lower() question = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', question) # 返回处理后的问题作为关键词 return question # Web search function for RAG with browser agent HTTP headers async def run_duckduckgo_search_tool(question: str): text = await understanding_question(question) keywords = text.split(",") headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0" } results = DDGS(headers=headers).text(keywords[0], max_results=5) print(results) return text # 修改后的web_search函数,使用run_duckduckgo_search_tool def web_search(query, max_results=3): try: # 设置浏览器代理HTTP头部 headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0" } with DDGS(headers=headers) as ddgs: results = list(ddgs.text(query, max_results=max_results)) if not results: return "No search results found." formatted_results = "\n\nSearch Results:\n" for i, result in enumerate(results, 1): formatted_results += f"[{i}] {result['title']}\n" formatted_results += f"URL: {result['href']}\n" formatted_results += f"Summary: {result['body']}\n\n" return formatted_results except Exception as e: print(f"Search error: {e}") return "" # Extract search query from user input def extract_search_query(text): # Look for questions or information requests in the text text = text.lower() # Remove any existing "User:" or "A:" prefixes text = re.sub(r'user:\s*|a:\s*', '', text) # Remove common question words that might not be relevant to the search text = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', text) # Limit query length return text[:100] # Prompt generation def generate_prompt(instruction, input=""): instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n') input = input.strip().replace('\r\n','\n').replace('\n\n','\n') if input: return f"""Instruction: {instruction} Input: {input} Response:""" else: return f"""User: hi A: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it. User: {instruction} A:""" # Evaluation logic with RAG enhancement def evaluate( ctx, token_count=200, temperature=1.0, top_p=0.7, presencePenalty = 0.1, countPenalty = 0.1, ): # Extract a search query from the user's input search_query = extract_search_query(ctx) # Perform web search if the context seems like a question search_results = "" if len(search_query) > 5 and not ctx.startswith("Assistant:"): search_results = web_search(search_query) # Combine original context with search results for RAG if search_results: # For prompts using the generate_prompt format if "User:" in ctx and "\n\nA:" in ctx: # Insert search results before the "A:" part parts = ctx.split("\n\nA:") rag_ctx = parts[0] + "\n\nRelevant Information:" + search_results + "\n\nA:" # For instruction format elif "Instruction:" in ctx and "\n\nResponse:" in ctx: # Insert search results before the "Response:" part parts = ctx.split("\n\nResponse:") rag_ctx = parts[0] + "\n\nRelevant Information:" + search_results + "\n\nResponse:" else: # For other formats, append to the end rag_ctx = ctx + "\n\nRelevant Information:" + search_results else: rag_ctx = ctx print("Context with RAG:") print(rag_ctx) args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), alpha_frequency = countPenalty, alpha_presence = presencePenalty, token_ban = [], # ban the generation of some tokens token_stop = [0]) # stop generation whenever you see any token here ctx = rag_ctx.strip() all_tokens = [] out_last = 0 out_str = '' occurrence = {} state = None for i in range(int(token_count)): out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state) for n in occurrence: out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p) if token in args.token_stop: break all_tokens += [token] for xxx in occurrence: occurrence[xxx] *= 0.996 if token not in occurrence: occurrence[token] = 1 else: occurrence[token] += 1 tmp = pipeline.decode(all_tokens[out_last:]) if '\ufffd' not in tmp: out_str += tmp yield out_str.strip() out_last = i + 1 if HAS_GPU == True : gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}') del out del state gc.collect() if HAS_GPU == True : # 在evaluate函数结束部分添加GPU内存清理 if HAS_GPU == True: if GPU_COUNT >= 2: # 清理两块GPU的缓存 for i in range(GPU_COUNT): with torch.cuda.device(f"cuda:{i}"): torch.cuda.empty_cache() if i < 2: # 只显示前两块GPU的信息 handle = nvmlDeviceGetHandleByIndex(i) gpu_info = nvmlDeviceGetMemoryInfo(handle) print(f'GPU {i} VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB') else: gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) print(f'GPU VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB') torch.cuda.empty_cache() yield out_str.strip() # Examples and gradio blocks examples = [ ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1], ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1], [generate_prompt("Tell me about ravens."), 333, 1, 0.3, 0, 1], [generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."), 333, 1, 0.3, 0, 1], [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 333, 1, 0.3, 0, 1], [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 333, 1, 0.3, 0, 1], ["Assistant: Here is a very detailed plan to kill all mosquitoes:", 333, 1, 0.3, 0, 1], ['''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world. User: Hello Edward. What have you been up to recently? Edward:''', 333, 1, 0.3, 0, 1], [generate_prompt("What are the latest developments in quantum computing?"), 333, 1, 0.3, 0, 1], [generate_prompt("Tell me about the current situation in Ukraine."), 333, 1, 0.3, 0, 1], ] ########################################################################## port=7860 use_frpc=True frpconfigfile="7680.ini" import subprocess def install_Frpc(port, frpconfigfile, use_frpc): if use_frpc: subprocess.run(['chmod', '+x', './frpc'], check=True) print(f'正在启动frp ,端口{port}') subprocess.Popen(['./frpc', '-c', frpconfigfile]) install_Frpc('7860',frpconfigfile,use_frpc) # Gradio blocks with gr.Blocks(title=title) as demo: gr.HTML(f"