# sd / 7b_rag.py
# decula
# changed 2 gpu t4
# 43735ce
import gradio as gr
import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *
from duckduckgo_search import DDGS
import re
import asyncio
# Whether at least one NVIDIA GPU was detected through NVML.
HAS_GPU = False
# Defaults for the "no GPU / NVML unavailable" case, so these names are always
# bound even when the probe below fails before assigning them.
GPU_COUNT = 0
gpu_h = None
# Upper bound on the number of prompt tokens passed to the model
# (evaluate() keeps only the last ctx_limit tokens of the encoded prompt).
ctx_limit = 20000
# NOTE(review): the title names a 3B / ctx4096 checkpoint while model_file
# points at a 7B one — confirm which label is intended.
title = "RWKV-5-World-3B-v2-20231025-ctx4096 with RAG"
model_file = "rwkv-5-h-world-7B"

# Probe NVML for GPUs; on success remember the count and the handle of GPU 0
# and print the name and total memory of every device.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)
        print(f"检测到 {GPU_COUNT} 个GPU设备")
        for i in range(GPU_COUNT):
            handle = nvmlDeviceGetHandleByIndex(i)
            info = nvmlDeviceGetMemoryInfo(handle)
            name = nvmlDeviceGetName(handle)
            print(f"GPU {i}: {name}, 总内存: {info.total / 1024**3:.2f} GB")
except NVMLError as error:
    # NVML missing or failing is not fatal: stay in CPU mode.
    print(error)
# Runtime switches for the rwkv package. They are set before rwkv is imported;
# presumably the package reads them at import time (TODO confirm).
os.environ["RWKV_JIT_ON"] = '1'
# '1' selects the CUDA kernel for seq mode (much faster); only enabled with a GPU.
os.environ["RWKV_CUDA_ON"] = '1' if HAS_GPU else '0'

# Pick the model loading strategy from the detected hardware.
if not HAS_GPU:
    MODEL_STRAT = "cpu bf16"
elif GPU_COUNT >= 2:
    # Split the model across two GPUs. Presumably "*10" puts the first 10
    # layers on cuda:0 and the rest on cuda:1 — confirm against the rwkv docs.
    MODEL_STRAT = "cuda:0 fp16 *10 -> cuda:1 fp16"
    print(f"使用多GPU策略: {MODEL_STRAT}")
else:
    MODEL_STRAT = "cuda fp16"
    print(f"使用单GPU策略: {MODEL_STRAT}")

# Download the checkpoint, load it and build the tokenizer pipeline.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

model_path = hf_hub_download(
    repo_id="a686d380/rwkv-5-h-world",
    filename=f"{model_file}.pth",
)
print(f"加载模型: {model_path}")
model = RWKV(model=model_path, strategy=MODEL_STRAT)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Reduce a question to the part that is useful as search keywords.
async def understanding_question(question: str):
    """Lower-case *question* and drop one leading conversational phrase.

    Only a phrase at the very start of the text is removed ("can you",
    "what is", "tell me about", ...), and only when whitespace follows it.
    The remaining text is returned unchanged.
    """
    lead_in = r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+'
    return re.sub(lead_in, '', question.lower())
# Web search helper for RAG that sends browser-like HTTP headers.
async def run_duckduckgo_search_tool(question: str):
    """Search DuckDuckGo for the first keyword of *question* and print the hits.

    The question is normalised with understanding_question(), split on commas,
    and only the first piece is used as the search term (at most 5 results).

    Returns the normalised question text.
    NOTE(review): the search results are only printed, not returned — confirm
    that callers really want the keyword text back.
    """
    text = await understanding_question(question)
    keywords = text.split(",")
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
    }
    # Use the client as a context manager, as web_search() does, so it is
    # closed after the request instead of being left open.
    with DDGS(headers=headers) as ddgs:
        results = ddgs.text(keywords[0], max_results=5)
        print(results)
    return text
# DuckDuckGo text search, formatted as a text block for the RAG prompt.
def web_search(query, max_results=3):
    """Search DuckDuckGo for *query* and return the hits as formatted text.

    Returns:
        A block starting with "Search Results:" that lists title, URL and
        summary of each hit; the string "No search results found." when the
        search yields nothing; or "" when any exception occurs (the error is
        printed, not raised).
    """
    # Browser-like User-Agent header sent with the request.
    browser_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
    }
    try:
        with DDGS(headers=browser_headers) as ddgs:
            hits = list(ddgs.text(query, max_results=max_results))
        if not hits:
            return "No search results found."
        entries = [
            f"[{rank}] {hit['title']}\n"
            f"URL: {hit['href']}\n"
            f"Summary: {hit['body']}\n\n"
            for rank, hit in enumerate(hits, 1)
        ]
        return "\n\nSearch Results:\n" + "".join(entries)
    except Exception as e:
        # Best effort: a failed search must not break generation.
        print(f"Search error: {e}")
        return ""
# Extract search query from user input
def extract_search_query(text):
    """Derive a short, lower-cased web-search query from a prompt.

    For chat or instruction prompts ("User: ..." / "Instruction: ..." lines)
    the query is taken from the LAST such turn, so a few-shot preamble such as
    "User: hi / A: Hi. I am your assistant..." does not end up as the query.
    For any other text the whole text is used, with leading "User:" / "A:"
    line labels removed.

    Whitespace runs are collapsed, one leading question phrase ("what is",
    "tell me about", ...) is dropped, and the result is cut to 100 characters.
    """
    text = text.lower()

    # A turn starts at a "user:" / "instruction:" line and runs until the next
    # "label:" line or the end of the text.
    turns = re.findall(
        r'^(?:user|instruction):[ \t]*(.*?)(?=^\w+:|\Z)',
        text,
        flags=re.MULTILINE | re.DOTALL,
    )
    turns = [turn.strip() for turn in turns if turn.strip()]
    if turns:
        text = turns[-1]
    else:
        # Strip role labels only at the start of a line; an unanchored "a:"
        # would also eat the end of ordinary words such as "data:".
        text = re.sub(r'^\s*(?:user|a):\s*', '', text, flags=re.MULTILINE)

    # Collapse newlines and repeated blanks so the query is a single line.
    text = ' '.join(text.split())
    # Remove common question words that might not be relevant to the search
    text = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', text)
    # Limit query length
    return text[:100]
# Prompt generation
def generate_prompt(instruction, input=""):
    """Build the text prompt for the model.

    With *input* the prompt has "Instruction:", "Input:" and "Response:"
    sections; without it, a short chat preamble followed by the user's
    instruction and an open "A:" turn.

    Sections are separated by a blank line. evaluate() looks for "\\n\\nA:" and
    "\\n\\nResponse:" to decide where to insert search results, so the separator
    is written explicitly here rather than left to the layout of a
    triple-quoted literal.

    Both arguments are stripped, CRLF is normalised to LF, and blank lines
    inside them are collapsed (a single replace pass).
    """
    instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    input = input.strip().replace('\r\n', '\n').replace('\n\n', '\n')
    if input:
        sections = [
            f"Instruction: {instruction}",
            f"Input: {input}",
            "Response:",
        ]
    else:
        sections = [
            "User: hi",
            "A: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
            f"User: {instruction}",
            "A:",
        ]
    return "\n\n".join(sections)
# Evaluation logic with RAG enhancement
def _merge_search_results(ctx, search_results):
    """Return *ctx* with the search results placed before the final answer cue.

    For chat prompts ("User:" ... "A:") and instruction prompts
    ("Instruction:" ... "Response:") the block is inserted before the LAST cue,
    so earlier turns and the actual question are kept. Any other prompt gets
    the block appended at the end.
    """
    rag_block = "\n\nRelevant Information:" + search_results
    for label, cue in (("User:", "A:"), ("Instruction:", "Response:")):
        cut = ctx.rfind("\n" + cue)
        if label in ctx and cut != -1:
            head = ctx[:cut].rstrip("\n")
            tail = ctx[cut + 1 + len(cue):]
            return head + rag_block + "\n\n" + cue + tail
    return ctx + rag_block


def _release_gpu_memory():
    """Empty the CUDA caches and print VRAM usage; does nothing without a GPU."""
    if not HAS_GPU:
        return
    if GPU_COUNT >= 2:
        # Clear the cache of every GPU the model may be spread over.
        for i in range(GPU_COUNT):
            with torch.cuda.device(f"cuda:{i}"):
                torch.cuda.empty_cache()
            if i < 2:  # only report the first two GPUs
                handle = nvmlDeviceGetHandleByIndex(i)
                gpu_info = nvmlDeviceGetMemoryInfo(handle)
                print(f'GPU {i} VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
    else:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        print(f'GPU VRAM: 总计 {gpu_info.total/(1024**3):.2f}GB, 已用 {gpu_info.used/(1024**3):.2f}GB, 空闲 {gpu_info.free/(1024**3):.2f}GB')
    torch.cuda.empty_cache()


def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Generate text for *ctx*, optionally augmented with web search results.

    This is a generator: it yields the text produced so far (stripped) each
    time newly sampled tokens decode cleanly, and once more at the end.

    Args:
        ctx: Prompt text. A search is run unless it starts with "Assistant:"
            or the extracted query is 5 characters or shorter.
        token_count: Maximum number of tokens to sample; 0 yields "".
        temperature: Sampling temperature, raised to at least 0.2.
        top_p: Nucleus sampling threshold.
        presencePenalty: Flat logit penalty for tokens already produced.
        countPenalty: Additional penalty scaled by the token's decayed
            occurrence count.
    """
    # Extract a search query from the user's input
    search_query = extract_search_query(ctx)

    # Perform web search if the context seems like a question
    search_results = ""
    if len(search_query) > 5 and not ctx.startswith("Assistant:"):
        search_results = web_search(search_query)

    # Combine original context with search results for RAG
    rag_ctx = _merge_search_results(ctx, search_results) if search_results else ctx
    print("Context with RAG:")
    print(rag_ctx)

    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                         alpha_frequency = countPenalty,
                         alpha_presence = presencePenalty,
                         token_ban = [], # ban the generation of some tokens
                         token_stop = [0]) # stop generation whenever you see any token here
    ctx = rag_ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Bound up front so cleanup works even if the loop body never runs.
    out = None
    state = None
    try:
        for i in range(int(token_count)):
            # First step feeds the (truncated) prompt, later steps only the
            # previously sampled token together with the recurrent state.
            out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens += [token]
            # Decay older occurrences so the penalty fades over time.
            for xxx in occurrence:
                occurrence[xxx] *= 0.996
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            tmp = pipeline.decode(all_tokens[out_last:])
            # Hold back text containing U+FFFD (presumably an incomplete
            # multi-byte character) until later tokens complete it.
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion, on errors, and when the generator is
        # closed early (e.g. cancelled from the UI). No yield may happen here.
        if HAS_GPU:
            gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
            print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
        out = None
        state = None
        gc.collect()
        _release_gpu_memory()
    yield out_str.strip()
# Example rows for the UI: every prompt uses the same generation settings, in
# the order max tokens, temperature, top-p, presence penalty, count penalty.
_EXAMPLE_SETTINGS = [333, 1, 0.3, 0, 1]
_EXAMPLE_PROMPTS = [
    "Assistant: Sure! Here is a very detailed plan to create flying pigs:",
    "Assistant: Sure! Here are some ideas for FTL drive:",
    generate_prompt("Tell me about ravens."),
    generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."),
    generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"),
    generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."),
    "Assistant: Here is a very detailed plan to kill all mosquitoes:",
    '''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.
User: Hello Edward. What have you been up to recently?
Edward:''',
    generate_prompt("What are the latest developments in quantum computing?"),
    generate_prompt("Tell me about the current situation in Ukraine."),
]
examples = [[prompt_text, *_EXAMPLE_SETTINGS] for prompt_text in _EXAMPLE_PROMPTS]
##########################################################################
# Settings for the optional frp tunnel (see install_Frpc).
import subprocess

port = 7860
use_frpc = True
# NOTE(review): "7680" vs. port 7860 looks like a possible digit swap — confirm
# the config file name.
frpconfigfile = "7680.ini"
def install_Frpc(port, frpconfigfile, use_frpc):
    """Start the frpc client in the background when *use_frpc* is true.

    Marks ./frpc as executable (raising CalledProcessError if chmod fails),
    prints a start-up message containing *port*, and launches
    ``./frpc -c <frpconfigfile>`` without waiting for it to finish.
    Does nothing when *use_frpc* is false.
    """
    if not use_frpc:
        return
    frpc_binary = './frpc'
    subprocess.run(['chmod', '+x', frpc_binary], check=True)
    print(f'正在启动frp ,端口{port}')
    subprocess.Popen([frpc_binary, '-c', frpconfigfile])
# Start the frp tunnel with the configured port instead of a duplicated literal.
install_Frpc(port, frpconfigfile, use_frpc)

# Gradio blocks
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 with RAG - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"这是带有RAG功能的RWKV-5 World v2模型。支持100多种世界语言和代码。演示限制上下文长度为{ctx_limit}。")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=2, label="提示词", value="")
                token_count = gr.Slider(0, 20000, label="最大Token数", step=200, value=100)
                temperature = gr.Slider(0.2, 2.0, label="温度", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                presence_penalty = gr.Slider(0.0, 1.0, label="存在惩罚", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="计数惩罚", step=0.1, value=1)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("提交", variant="primary")
                    stop_btn = gr.Button("中断", variant="stop")
                    clear = gr.Button("清除", variant="secondary")
                output = gr.Textbox(label="输出", lines=5)
        # Pass the example rows explicitly; without `samples` the dataset is empty.
        data = gr.Dataset(
            components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
            samples=examples,
            label="示例指令",
            headers=["提示词", "最大Token数", "温度", "Top P", "存在惩罚", "计数惩罚"],
        )
        # Submit: stream the output of evaluate() into the output box.
        submit_event = submit.click(
            evaluate,
            [prompt, token_count, temperature, top_p, presence_penalty, count_penalty],
            [output]
        )
        # Stop: cancel the generation that is currently running.
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[submit_event]
        )
        clear.click(lambda: None, [], [output])
        # Clicking an example row copies its values into the input components.
        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

# Gradio launch
# evaluate() is a generator and the stop button uses `cancels`; enable the
# queue explicitly, since some Gradio versions require it for both.
demo.queue(max_size=10)
demo.launch(share=False)