| import gradio as gr |
| import os, gc, copy, torch |
| from datetime import datetime |
| from huggingface_hub import hf_hub_download |
| from pynvml import * |
| from duckduckgo_search import DDGS |
| import re |
| import asyncio |
|
|
| |
# Set to True by the NVML probe below once at least one GPU is reported.
HAS_GPU = False

# Prompt token budget: the encoded prompt is cut to its last ``ctx_limit``
# tokens before being fed to the model.
ctx_limit = 20_000
# Page title / heading text for the Gradio app.
# NOTE(review): this label names a 3B ctx4096 model, but ``model_file`` below
# loads "rwkv-5-h-world-7B" — confirm which one is intended.
title = "RWKV-5-World-3B-v2-20231025-ctx4096 with RAG"
# Checkpoint base name; ".pth" is appended when downloading from the HF Hub.
model_file = "rwkv-5-h-world-7B"
|
|
| |
# CPU-only defaults, so these names exist even when NVML is unavailable or
# reports no GPU (later code reads them).
GPU_COUNT = 0
gpu_h = None

try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        # Handle of the first GPU, reused later for VRAM reporting.
        gpu_h = nvmlDeviceGetHandleByIndex(0)
        print(f"检测到 {GPU_COUNT} 个GPU设备")
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            name = nvmlDeviceGetName(handle)
            # Some pynvml versions return the device name as bytes; decode it
            # so the log shows "NVIDIA ..." rather than "b'NVIDIA ...'".
            if isinstance(name, bytes):
                name = name.decode("utf-8", errors="replace")
            print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB")
except NVMLError as error:
    # NVML could not be initialised/queried: report it and keep the
    # CPU-only defaults set above.
    print(error)
|
|
|
|
# RWKV runtime flags. They are set before the ``rwkv`` import further down,
# presumably because the rwkv package reads them at import time.
os.environ["RWKV_JIT_ON"] = '1'

# CPU fallback: rwkv CUDA kernel off, bf16 weights on the CPU.
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'

if HAS_GPU:
    # GPU present: turn the rwkv CUDA flag on and run in fp16.
    os.environ["RWKV_CUDA_ON"] = '1'
    if GPU_COUNT < 2:
        MODEL_STRAT = "cuda fp16"
        print(f"使用单GPU策略: {MODEL_STRAT}")
    else:
        # Split the model across cuda:0 and cuda:1 using rwkv's strategy
        # syntax ("*10" presumably places the first 10 layers on cuda:0 —
        # confirm against the rwkv strategy documentation).
        MODEL_STRAT = "cuda:0 fp16 *10 -> cuda:1 fp16"
        print(f"使用多GPU策略: {MODEL_STRAT}")
|
|
| |
# rwkv is imported only here, after the RWKV_* environment flags are set.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Fetch the checkpoint from the Hugging Face Hub and load it with the
# strategy chosen above.
model_path = hf_hub_download(
    repo_id="a686d380/rwkv-5-h-world",
    filename=f"{model_file}.pth",
)
print(f"加载模型: {model_path}")
model = RWKV(model=model_path, strategy=MODEL_STRAT)

# Tokenizer + sampling helper bound to the loaded model.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
|
| |
async def understanding_question(question: str):
    """Normalise a question into a search-friendly string.

    The question is lower-cased and a single leading filler phrase (for
    example "can you" or "what is", followed by whitespace) is removed.
    """
    lowered = question.lower()
    return re.sub(
        r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+',
        '',
        lowered,
    )
|
|
| |
async def run_duckduckgo_search_tool(question: str):
    """Search DuckDuckGo for a normalised ``question`` and print the hits.

    The question is normalised with ``understanding_question`` and only the
    text before the first comma is searched (at most 5 results). The hits
    are printed, not returned; search errors propagate to the caller.

    Returns:
        The normalised question text (not the search results).
    """
    text = await understanding_question(question)

    # Only the first comma-separated keyword is searched.
    keywords = text.split(",")
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
    }

    # Use DDGS as a context manager so it can release its resources, and
    # materialise the hits: depending on the duckduckgo_search version,
    # ``text()`` may return a generator, which would print as a generator
    # object instead of the results.
    # NOTE(review): this is a blocking network call inside a coroutine and
    # stalls the event loop while it runs.
    with DDGS(headers=headers) as ddgs:
        results = list(ddgs.text(keywords[0], max_results=5))
    print(results)

    return text
|
|
| |
def web_search(query, max_results=3):
    """Search DuckDuckGo and format the hits as a text block for the prompt.

    Args:
        query: Search string passed to ``DDGS.text``.
        max_results: Maximum number of hits to request.

    Returns:
        A blank line plus "Search Results:" followed by one numbered entry
        (title, URL, summary) per hit; the sentence
        "No search results found." when the search returns nothing; or ""
        when the search itself fails. Failures are printed, not raised, so a
        broken search never aborts generation.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
    }
    try:
        with DDGS(headers=headers) as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
    except Exception as e:
        # Deliberate best-effort: any search failure just means no extra context.
        print(f"Search error: {e}")
        return ""

    if not results:
        return "No search results found."

    # ``.get`` so one hit with a missing field does not discard all results.
    entries = [
        f"[{i}] {result.get('title', '')}\n"
        f"URL: {result.get('href', '')}\n"
        f"Summary: {result.get('body', '')}\n\n"
        for i, result in enumerate(results, 1)
    ]
    return "\n\nSearch Results:\n" + "".join(entries)
|
|
| |
def extract_search_query(text, max_length=100):
    """Derive a short web-search query from a raw model prompt.

    For chat-style prompts only the most recent "User:" turn is used (up to
    the next blank line), so a canned greeting exchange at the top of the
    prompt is not searched in place of the real question. Role labels that
    start a line ("User:", "A:", "Instruction:", "Input:", "Response:") are
    removed, whitespace is collapsed to single spaces, and one leading filler
    phrase such as "can you" or "what is" is stripped.

    Args:
        text: The prompt text.
        max_length: Maximum length of the returned query (default 100).

    Returns:
        The lower-cased query, at most ``max_length`` characters long.
    """
    query = text.lower()

    turns = re.split(r'(?m)^[ \t]*user:', query)
    if len(turns) > 1:
        # Turns are separated by blank lines; keep only the latest user turn.
        query = turns[-1].split("\n\n", 1)[0]

    # Strip labels only at the start of a line, so words that merely end in
    # "a:" (e.g. "data:") are left intact.
    query = re.sub(r'(?m)^\s*(?:user|a|instruction|input|response):\s*', ' ', query)
    query = " ".join(query.split())

    query = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', query)

    return query[:max_length]
|
|
| |
def generate_prompt(instruction, input=""):
    """Wrap ``instruction`` (and optional ``input``) in the model's prompt format.

    Both strings are stripped, CRLF line endings become LF, and every run of
    consecutive newlines is collapsed to one, so user text never contains
    the blank line that separates turns in the prompt.

    Args:
        instruction: The user's request.
        input: Optional extra context. When non-empty an
            "Instruction / Input / Response" prompt is built; otherwise a
            "User / A" chat prompt preceded by a fixed greeting exchange.
            (The name shadows the builtin but is kept for keyword callers.)

    Returns:
        The prompt, ending in "Response:" or "A:" for the model to continue.
    """
    def _normalise(text):
        # A single .replace('\n\n', '\n') leaves "\n\n" behind when three or
        # more newlines are adjacent, so collapse each run in one pass.
        return re.sub(r'\n{2,}', '\n', text.strip().replace('\r\n', '\n'))

    instruction = _normalise(instruction)
    input = _normalise(input)
    if input:
        return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
    return (
        "User: hi\n\n"
        "A: Hi. I am your assistant and I will provide expert full response in full details. "
        "Please feel free to ask any question and I will always answer it.\n\n"
        f"User: {instruction}\n\n"
        "A:"
    )
|
|
| |
# web_search() returns this sentence when DuckDuckGo finds nothing. It carries
# no information for the model, so it must not be injected into the prompt.
_NO_SEARCH_RESULTS = "No search results found."


def _inject_search_results(ctx, search_results):
    """Return ``ctx`` with ``search_results`` added as "Relevant Information".

    For chat prompts ("User:" ... "\\n\\nA:") and instruction prompts
    ("Instruction:" ... "\\n\\nResponse:") the block is inserted right before
    the LAST answer marker, keeping every earlier turn and any text after
    that marker. Other prompts get the block appended at the end.
    """
    if "User:" in ctx and "\n\nA:" in ctx:
        marker = "\n\nA:"
    elif "Instruction:" in ctx and "\n\nResponse:" in ctx:
        marker = "\n\nResponse:"
    else:
        return ctx + "\n\nRelevant Information:" + search_results
    # rpartition rather than split(marker)[0]: cutting at the FIRST marker
    # drops every later turn, e.g. the real question that follows a canned
    # greeting exchange.
    head, _, tail = ctx.rpartition(marker)
    return head + "\n\nRelevant Information:" + search_results + marker + tail


def _release_memory():
    """Run the garbage collector and, on GPU hosts, empty the CUDA cache.

    Also prints VRAM usage: for the first two GPUs on multi-GPU hosts, or
    for the single GPU otherwise.
    """
    gc.collect()
    if not HAS_GPU:
        return
    if GPU_COUNT >= 2:
        for i in range(GPU_COUNT):
            with torch.cuda.device(f"cuda:{i}"):
                torch.cuda.empty_cache()
            if i < 2:
                handle = nvmlDeviceGetHandleByIndex(i)
                gpu_info = nvmlDeviceGetMemoryInfo(handle)
                print(f'GPU {i} VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
    else:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        print(f'GPU VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
        torch.cuda.empty_cache()


def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty=0.1,
    countPenalty=0.1,
):
    """Stream a completion for ``ctx``, augmented with web search results (RAG).

    A search query is derived from ``ctx``. Unless the prompt starts with
    "Assistant:" or the query is 5 characters or shorter, DuckDuckGo results
    are inserted into the prompt before generation.

    Args:
        ctx: The raw prompt (e.g. produced by ``generate_prompt``).
        token_count: Maximum number of tokens to sample.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus-sampling threshold.
        presencePenalty: Flat logit penalty for every token already generated.
        countPenalty: Additional logit penalty per (decaying) occurrence count.

    Yields:
        The stripped text generated so far, each time the pending tokens
        decode without a replacement character, and once more at the end.
    """
    search_query = extract_search_query(ctx)

    # Prompts starting with "Assistant:" and very short queries skip the search.
    search_results = ""
    if len(search_query) > 5 and not ctx.startswith("Assistant:"):
        search_results = web_search(search_query)

    # web_search() returns "" on failure and a placeholder sentence when
    # nothing was found; only real results are added to the prompt.
    if search_results and search_results != _NO_SEARCH_RESULTS:
        rag_ctx = _inject_search_results(ctx, search_results)
    else:
        rag_ctx = ctx

    print("Context with RAG:")
    print(rag_ctx)

    args = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],
        token_stop=[0],
    )
    ctx = rag_ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Bound up front so the cleanup below also works when token_count is 0
    # (the UI slider allows it) and the loop body never runs.
    out = None
    state = None
    try:
        for i in range(int(token_count)):
            # Step 0 feeds the prompt (keeping only its last ctx_limit tokens);
            # later steps feed only the previous token and continue from ``state``.
            tokens = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
            out, state = model.forward(tokens, state)
            # Repetition penalty: a flat presence term plus a term that grows
            # with each token's (decaying) occurrence count.
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            for seen in occurrence:
                occurrence[seen] *= 0.996
            occurrence[token] = occurrence.get(token, 0) + 1

            tmp = pipeline.decode(all_tokens[out_last:])
            # '\ufffd' presumably means the pending tokens stop part-way through
            # a character; hold them back until a later token completes it.
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion, on errors, and when the generator is
        # closed before it is exhausted, so memory is always released.
        if HAS_GPU:
            gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
            print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
        # Drop the last logits and recurrent state before collecting so their
        # memory can actually be reclaimed.
        del out, state
        _release_memory()

    yield out_str.strip()
|
|
| |
# Example rows for the UI's example table. Each row lines up with the six
# inputs of ``evaluate`` (prompt, token count, temperature, top_p, presence
# penalty, count penalty); every row uses the same sampling settings.
examples = [
    [example_prompt, 333, 1, 0.3, 0, 1]
    for example_prompt in (
        "Assistant: Sure! Here is a very detailed plan to create flying pigs:",
        "Assistant: Sure! Here are some ideas for FTL drive:",
        generate_prompt("Tell me about ravens."),
        generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."),
        generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"),
        generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."),
        "Assistant: Here is a very detailed plan to kill all mosquitoes:",
        (
            "Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.\n"
            "\n"
            "User: Hello Edward. What have you been up to recently?\n"
            "\n"
            "Edward:"
        ),
        generate_prompt("What are the latest developments in quantum computing?"),
        generate_prompt("Tell me about the current situation in Ukraine."),
    )
]
|
|
| |
# frp tunnel settings.
port = 7860
use_frpc = True
# NOTE(review): "7680.ini" vs port 7860 looks like transposed digits; confirm
# the actual config file name before renaming either.
frpconfigfile = "7680.ini"
import subprocess


def install_Frpc(port, frpconfigfile, use_frpc):
    """Launch the local ``./frpc`` binary in the background, if enabled.

    Args:
        port: Local port; only used in the status message.
        frpconfigfile: Config file passed to ``frpc -c``.
        use_frpc: When false, nothing is done.

    Raises:
        subprocess.CalledProcessError: if ``chmod +x ./frpc`` exits non-zero
            (for example when the binary is missing).
    """
    if not use_frpc:
        return
    # Ensure the bundled binary is executable before starting it.
    subprocess.run(['chmod', '+x', './frpc'], check=True)
    print(f'正在启动frp ,端口{port}')
    # Fire-and-forget: the process is not waited on.
    subprocess.Popen(['./frpc', '-c', frpconfigfile])


# Pass the configured ``port`` rather than repeating the number as a literal.
install_Frpc(port, frpconfigfile, use_frpc)
|
|
| |
# Gradio UI: a "Raw Generation" tab with prompt and sampling controls, a
# streaming output box, and a clickable table of example prompts.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 with RAG - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"这是带有RAG功能的RWKV-5 World v2模型。支持100多种世界语言和代码。演示限制上下文长度为{ctx_limit}。")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=2, label="提示词", value="")
                token_count = gr.Slider(0, 20000, label="最大Token数", step=200, value=100)
                temperature = gr.Slider(0.2, 2.0, label="温度", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                presence_penalty = gr.Slider(0.0, 1.0, label="存在惩罚", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="计数惩罚", step=0.1, value=1)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("提交", variant="primary")
                    stop_btn = gr.Button("中断", variant="stop")
                    clear = gr.Button("清除", variant="secondary")
                output = gr.Textbox(label="输出", lines=5)
        # ``samples=examples`` fills the table; without it the table is empty
        # and the ``examples`` list is never used.
        data = gr.Dataset(
            components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
            samples=examples,
            label="示例指令",
            headers=["提示词", "最大Token数", "温度", "Top P", "存在惩罚", "计数惩罚"],
        )

        # ``evaluate`` is a generator, so the output box updates while it streams.
        submit_event = submit.click(
            evaluate,
            [prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
            [output],
        )

        # The stop button runs no function; it only cancels the running generation.
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])

        clear.click(lambda: None, [], [output])
        # Clicking an example row copies its six values into the input controls.
        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

# Streaming generators and ``cancels=`` need Gradio's queue. It is on by default
# in recent Gradio releases; enabling it explicitly keeps older releases working.
demo.queue()
demo.launch(share=False)
|
|