decula commited on
Commit
3bb71dd
·
1 Parent(s): e0fb7cc

Added 7b_dual.py

Browse files
Files changed (2) hide show
  1. 7b_dual.py +186 -0
  2. gemini_excel_rag.py +44 -13
7b_dual.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
# Explicit imports instead of `from pynvml import *`, so every NVML name
# this module relies on is visible at a glance.
from pynvml import (
    NVMLError,
    nvmlDeviceGetCount,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlInit,
)

# Flag to check if a CUDA-capable GPU is present.
HAS_GPU = False

# Model title and context size limit.
ctx_limit = 20000
title = "RWKV-5-World-7B-v2-Dual-GPU"
model_file = "rwkv-5-h-world-7B"

# Probe the GPU count via NVML; on any NVML failure fall back to CPU mode.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        # Handle of the primary card, used later for VRAM usage printouts.
        gpu_h = nvmlDeviceGetHandleByIndex(0)
except NVMLError as error:
    print(error)
    GPU_COUNT = 0

os.environ["RWKV_JIT_ON"] = '1'

# --- Multi-GPU strategy selection ---
# Default: CPU with bf16 weights.
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0'

if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    if GPU_COUNT >= 2:
        # Strategy: place 16 layers on cuda:0 and the remaining layers
        # (plus the head) on cuda:1, all in bf16 — 16 GB per card is
        # comfortably enough for a 7B model.
        MODEL_STRAT = "cuda:0 bf16 * 16 -> cuda:1 bf16"
        print(f"检测到 {GPU_COUNT} 块显卡,启用双卡策略: {MODEL_STRAT}")
    else:
        MODEL_STRAT = "cuda bf16"
        print("检测到单块显卡,启用单卡 bf16")
# ------------------------------

# Load the model with the chosen strategy.
from rwkv.model import RWKV
# NOTE(review): confirm the repo_id — RWKV-5 checkpoints usually live under
# the BlinkDL organisation; a686d380/rwkv-5-h-world looks like a mirror.
model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)

from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
# Prompt generation
def generate_prompt(instruction, input=""):
    """Format a prompt for the RWKV world model.

    With a non-empty ``input`` an Instruction/Input/Response template is
    used; otherwise a short chat transcript primes the model as a helpful
    assistant.  (``input`` shadows the builtin, but the parameter name is
    kept for caller compatibility.)
    """
    def _clean(text):
        # Normalise CRLF line endings, then collapse one level of blank lines.
        return text.strip().replace('\r\n', '\n').replace('\n\n', '\n')

    instruction = _clean(instruction)
    input = _clean(input)

    if not input:
        # Chat-style priming when there is no separate input payload.
        return (
            "User: hi\n"
            "\n"
            "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n"
            "\n"
            f"User: {instruction}\n"
            "\n"
            "Assistant:"
        )

    return (
        f"Instruction: {instruction}\n"
        "\n"
        f"Input: {input}\n"
        "\n"
        "Response:"
    )
# Evaluation logic
def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Stream tokens from the RWKV model for the given prompt.

    Args:
        ctx: Prompt text; only the last `ctx_limit` tokens are fed to the model.
        token_count: Maximum number of tokens to generate.
        temperature: Sampling temperature (clamped to a minimum of 0.2).
        top_p: Nucleus-sampling probability mass.
        presencePenalty: Flat penalty for any token already generated.
        countPenalty: Penalty scaled by how often a token has appeared.

    Yields:
        The decoded output so far (stripped), updated as tokens arrive.
    """
    print(ctx)
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [], # ban the generation of some tokens
                     token_stop = [0]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    # Fix: bind `out` before the loop so the `del out` cleanup below cannot
    # raise NameError when token_count == 0 and the loop never runs.
    out = None
    for i in range(int(token_count)):
        # First step feeds the (truncated) prompt; later steps feed only the
        # previously sampled token, with the recurrent state carried along.
        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
        # Apply presence/frequency repetition penalties to the logits.
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        # Exponentially decay old penalties, then record the new token.
        for xxx in occurrence:
            occurrence[xxx] *= 0.996
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

        # Only emit once the pending tokens decode to valid UTF-8
        # (no replacement character), to avoid yielding half a codepoint.
        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1

    if HAS_GPU:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')

    # Release the logits and recurrent state before collecting garbage.
    del out
    del state
    gc.collect()

    if HAS_GPU:
        torch.cuda.empty_cache()

    yield out_str.strip()
# Examples for the Gradio Dataset component.
# Every example shares the same sampling settings, in the order
# (max_tokens, temperature, top_p, presence_penalty, count_penalty).
_EXAMPLE_ARGS = [333, 1, 0.3, 0, 1]

_EXAMPLE_PROMPTS = [
    "Assistant: Sure! Here is a very detailed plan to create flying pigs:",
    "Assistant: Sure! Here are some ideas for FTL drive:",
    generate_prompt("Tell me about ravens."),
    generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."),
    generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"),
    generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."),
    "Assistant: Here is a very detailed plan to kill all mosquitoes:",
    # Roleplay example: a raw chat transcript instead of the template.
    "Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.\n\nUser: Hello Edward. What have you been up to recently?\n\nEdward:",
    generate_prompt(""),
    "",
]

examples = [[prompt] + _EXAMPLE_ARGS for prompt in _EXAMPLE_PROMPTS]
##########################################################################
# frp reverse-proxy bootstrap: exposes the local Gradio port publicly.
port=7860
use_frpc=True
# NOTE(review): "7680.ini" looks like a digit transposition of port 7860 —
# confirm the config file really is named this way.
frpconfigfile="7680.ini"
import subprocess

def install_Frpc(port, frpconfigfile, use_frpc):
    """Make the bundled ./frpc binary executable and launch it in background.

    Args:
        port: Local port being tunnelled (printed for information only).
        frpconfigfile: Path to the frpc .ini configuration file.
        use_frpc: When falsy, this function is a no-op.

    Raises:
        subprocess.CalledProcessError: if the chmod step fails.
    """
    if use_frpc:
        subprocess.run(['chmod', '+x', './frpc'], check=True)
        print(f'正在启动frp ,端口{port}')
        # Popen (not run): frpc must keep running while Gradio serves.
        subprocess.Popen(['./frpc', '-c', frpconfigfile])

# Fixed inconsistency: pass the configured `port` variable instead of the
# hard-coded string '7860' that ignored it.
install_Frpc(port, frpconfigfile, use_frpc)
# Gradio blocks — the web UI for raw generation.
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        gr.Markdown(f"This is RWKV-5 World v2 with 7B params (Dual GPU) - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports all 100+ world languages and code. Demo limited to ctxlen {ctx_limit}.")
        with gr.Row():
            with gr.Column():
                # Sampling controls, declared in the positional order
                # that `evaluate` expects its arguments.
                prompt = gr.Textbox(lines=2, label="Prompt", value="")
                token_count = gr.Slider(0, 20000, label="Max Tokens", step=200, value=100)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=1)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    clear = gr.Button("Clear", variant="secondary")
                output = gr.Textbox(label="Output", lines=5)
        # The full set of generation inputs, shared by the event handlers.
        gen_inputs = [prompt, token_count, temperature, top_p, presence_penalty, count_penalty]
        data = gr.Dataset(components=gen_inputs, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"], samples=examples)
        submit.click(evaluate, gen_inputs, [output])
        clear.click(lambda: None, [], [output])
        # Clicking an example row copies its values back into the inputs.
        data.click(lambda *sample: sample, [data], gen_inputs)

# Gradio launch
demo.launch(share=False)
gemini_excel_rag.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
  import re
16
 
17
  # 设置Google API密钥
18
- os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY" # 请替换为您的API密钥
19
 
20
  # 设置向量数据库存储路径
21
  VECTOR_STORE_PATH = "./vector_store"
@@ -48,7 +48,7 @@ def get_vectorstore():
48
  def get_llm():
49
  """初始化Gemini Flash 2.0模型"""
50
  return ChatGoogleGenerativeAI(
51
- model="gemini-flash-2.0",
52
  temperature=0.7,
53
  convert_system_message_to_human=True,
54
  max_output_tokens=2048
@@ -110,24 +110,42 @@ def process_excel_with_markitdown(file_path):
110
 
111
  # 使用pandas直接处理Excel文件并添加到向量数据库
112
  def process_excel_with_pandas(file_path):
113
- """使用pandas处理Excel文件并添加到向量数据库"""
114
  try:
115
  # 读取Excel文件
116
  df = pd.read_excel(file_path)
117
 
118
- # 将每个表格行转换为文本
119
  documents = []
120
  for idx, row in df.iterrows():
121
- # 转换为字符串格式
122
  row_text = "\n".join([f"{col}: {val}" for col, val in row.items() if not pd.isna(val)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # 创建文档
124
  doc = Document(
125
  page_content=row_text,
126
- metadata={
127
- "source": file_path,
128
- "row": idx,
129
- "sheet": "Sheet1" # 如果需要处理多个sheet,可以在这里修改
130
- }
131
  )
132
  documents.append(doc)
133
 
@@ -136,7 +154,7 @@ def process_excel_with_pandas(file_path):
136
  vectorstore.add_documents(documents)
137
  vectorstore.persist()
138
 
139
- return f"成功处理Excel文件: {file_path},添加了{len(documents)}个行记录到向量数据库"
140
  except Exception as e:
141
  return f"处理Excel文件时出错: {str(e)}"
142
 
@@ -155,6 +173,19 @@ def answer_question(question):
155
  return response
156
 
157
  # 创建Gradio界面
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def create_interface():
159
  with gr.Blocks(title="Gemini Flash 2.0 Excel RAG") as demo:
160
  gr.HTML("<h1 style='text-align: center'>Gemini Flash 2.0 Excel RAG 系统</h1>")
@@ -162,7 +193,7 @@ def create_interface():
162
  with gr.Tab("导入Excel数据"):
163
  with gr.Row():
164
  excel_file = gr.File(label="上传Excel文件")
165
- process_method = gr.Radio(["使用MarkItDown处理", "使用Pandas处理"], label="处理方法", value="使用MarkItDown处理")
166
  process_btn = gr.Button("处理并导入到向量数据库")
167
  output_msg = gr.Textbox(label="处理结果")
168
 
@@ -218,7 +249,7 @@ def create_interface():
218
 
219
  def main():
220
  demo = create_interface()
221
- demo.launch(share=False)
222
 
223
  if __name__ == "__main__":
224
  main()
 
15
  import re
16
 
17
  # 设置Google API密钥
18
+ os.environ["GOOGLE_API_KEY"] = "YOUR_GOOGLE_API_KEY"  # SECURITY: a real API key was committed on this line — revoke it immediately and load the key from an environment variable or secrets manager instead of hard-coding it
19
 
20
  # 设置向量数据库存储路径
21
  VECTOR_STORE_PATH = "./vector_store"
 
48
  def get_llm():
49
  """初始化Gemini Flash 2.0模型"""
50
  return ChatGoogleGenerativeAI(
51
+ model="gemini-2.0-flash-exp",
52
  temperature=0.7,
53
  convert_system_message_to_human=True,
54
  max_output_tokens=2048
 
110
 
111
  # 使用pandas直接处理Excel文件并添加到向量数据库
112
  def process_excel_with_pandas(file_path):
113
+ """使用pandas处理Excel文件并添加到向量数据库,将每列作为单独的元数据字段"""
114
  try:
115
  # 读取Excel文件
116
  df = pd.read_excel(file_path)
117
 
118
+ # 将每个表格行转换为文本和元数据
119
  documents = []
120
  for idx, row in df.iterrows():
121
+ # 创建文本内容(用于向量化和检索)
122
  row_text = "\n".join([f"{col}: {val}" for col, val in row.items() if not pd.isna(val)])
123
+
124
+ # 创建元数据字典,包含每列的值
125
+ metadata = {
126
+ "source": file_path,
127
+ "row": idx,
128
+ "sheet": "Sheet1" # 如果需要处理多个sheet,可以在这里修改
129
+ }
130
+
131
+ # 将每列的值添加到元数据中
132
+ for col, val in row.items():
133
+ # 处理不同类型的数据
134
+ if isinstance(val, (int, float)) and not pd.isna(val):
135
+ metadata[f"col_{col}"] = val
136
+ elif isinstance(val, str) and val.strip():
137
+ metadata[f"col_{col}"] = val.strip()
138
+ elif pd.isna(val):
139
+ # 跳过空值
140
+ continue
141
+ else:
142
+ # 其他类型转为字符串
143
+ metadata[f"col_{col}"] = str(val)
144
+
145
  # 创建文档
146
  doc = Document(
147
  page_content=row_text,
148
+ metadata=metadata
 
 
 
 
149
  )
150
  documents.append(doc)
151
 
 
154
  vectorstore.add_documents(documents)
155
  vectorstore.persist()
156
 
157
+ return f"成功处理Excel文件: {file_path},添加了{len(documents)}个行记录到向量数据库,每行包含{len(df.columns)}个字段"
158
  except Exception as e:
159
  return f"处理Excel文件时出错: {str(e)}"
160
 
 
173
  return response
174
 
175
  # 创建Gradio界面
176
+ port=7860
177
+ use_frpc=True
178
+ frpconfigfile="7680.ini"
179
+ import subprocess
180
+
181
+ def install_Frpc(port, frpconfigfile, use_frpc):
182
+ if use_frpc:
183
+ subprocess.run(['chmod', '+x', './frpc'], check=True)
184
+ print(f'正在启动frp ,端口{port}')
185
+ subprocess.Popen(['./frpc', '-c', frpconfigfile])
186
+
187
+ install_Frpc('7860',frpconfigfile,use_frpc)
188
+
189
  def create_interface():
190
  with gr.Blocks(title="Gemini Flash 2.0 Excel RAG") as demo:
191
  gr.HTML("<h1 style='text-align: center'>Gemini Flash 2.0 Excel RAG 系统</h1>")
 
193
  with gr.Tab("导入Excel数据"):
194
  with gr.Row():
195
  excel_file = gr.File(label="上传Excel文件")
196
+ process_method = gr.Radio(["使用MarkItDown处理", "使用Pandas处理"], label="处理方法", value="使用Pandas处理")
197
  process_btn = gr.Button("处理并导入到向量数据库")
198
  output_msg = gr.Textbox(label="处理结果")
199
 
 
249
 
250
  def main():
251
  demo = create_interface()
252
+ demo.launch(share=True)
253
 
254
  if __name__ == "__main__":
255
  main()