decula committed on
Commit
6e2a9d9
·
1 Parent(s): 43a4e1b

Added 7B and 3B

Browse files
Files changed (2) hide show
  1. 3b.py +1 -1
  2. 7b.py +175 -0
3b.py CHANGED
@@ -8,7 +8,7 @@ from pynvml import *
8
  HAS_GPU = False
9
 
10
  # Model title and context size limit
11
- ctx_limit = 2000
12
  title = "RWKV-5-World-3B-v2-20231025-ctx4096"
13
  model_file = "rwkv-5-h-world-3B"
14
 
 
8
  HAS_GPU = False
9
 
10
  # Model title and context size limit
11
+ ctx_limit = 10000
12
  title = "RWKV-5-World-3B-v2-20231025-ctx4096"
13
  model_file = "rwkv-5-h-world-3B"
14
 
7b.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Imports --------------------------------------------------------------
import gradio as gr
import os, gc, copy, torch
from datetime import datetime
from huggingface_hub import hf_hub_download
from pynvml import *

# Flag to check if GPU is present
HAS_GPU = False

# Model title and context size limit.
# NOTE(review): `title` still says "3B" / "ctx4096" while this script loads
# the 7B checkpoint and truncates prompts at 20000 tokens — confirm whether
# the title should be updated to match the 7B model.
ctx_limit = 20000
title = "RWKV-5-World-3B-v2-20231025-ctx4096"
model_file = "rwkv-5-h-world-7B"

# Probe for GPUs via NVML; any NVML failure leaves us in CPU mode.
try:
    nvmlInit()
    GPU_COUNT = nvmlDeviceGetCount()
    if GPU_COUNT > 0:
        HAS_GPU = True
        gpu_h = nvmlDeviceGetHandleByIndex(0)  # first GPU handle, used later for VRAM logging
except NVMLError as error:
    print(error)

os.environ["RWKV_JIT_ON"] = '1'

# Execution strategy: CPU bf16 by default, switched to CUDA below when a GPU exists.
MODEL_STRAT = "cpu bf16"
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)

if HAS_GPU:
    os.environ["RWKV_CUDA_ON"] = '1'
    MODEL_STRAT = "cuda bf16"

# Download the checkpoint and build model + tokenizer pipeline.
# The rwkv imports stay below the RWKV_* env-var assignments, preserving the
# original ordering.
from rwkv.model import RWKV
model_path = hf_hub_download(repo_id="a686d380/rwkv-5-h-world", filename=f"{model_file}.pth")
model = RWKV(model=model_path, strategy=MODEL_STRAT)
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
43
+
44
# Prompt generation
def generate_prompt(instruction, input=""):
    """Build the text prompt fed to the model.

    With a non-empty `input`, returns an Instruction/Input/Response form;
    otherwise returns a short chat transcript ending at "Assistant:".
    Both fields are stripped, CRLF-normalized, and have double newlines
    collapsed (one left-to-right pass).
    """
    def _normalize(text):
        return text.strip().replace('\r\n', '\n').replace('\n\n', '\n')

    instruction = _normalize(instruction)
    input = _normalize(input)

    if input:
        return (
            f"Instruction: {instruction}\n"
            "\n"
            f"Input: {input}\n"
            "\n"
            "Response:"
        )

    return (
        "User: hi\n"
        "\n"
        "Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n"
        "\n"
        f"User: {instruction}\n"
        "\n"
        "Assistant:"
    )
62
+
63
# Evaluation logic
def evaluate(
    ctx,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Generator that streams text sampled from the RWKV model.

    Yields the accumulated (stripped) output after every cleanly-decodable
    chunk, plus once more at the end. Temperature is clamped to >= 0.2 and
    token id 0 terminates generation early.
    """
    print(ctx)
    args = PIPELINE_ARGS(
        temperature=max(0.2, float(temperature)),
        top_p=float(top_p),
        alpha_frequency=countPenalty,
        alpha_presence=presencePenalty,
        token_ban=[],     # ban the generation of some tokens
        token_stop=[0],   # stop generation whenever you see any token here
    )

    ctx = ctx.strip()
    generated = []      # token ids sampled so far
    emitted_upto = 0    # first index of `generated` not yet flushed to out_str
    out_str = ''
    occurrence = {}     # token id -> decayed repetition count (for penalties)
    state = None

    for step in range(int(token_count)):
        # Feed the whole (ctx_limit-truncated) prompt on step 0, then one id at a time.
        feed = pipeline.encode(ctx)[-ctx_limit:] if step == 0 else [token]
        out, state = model.forward(feed, state)

        # Penalize tokens that already appeared (presence + frequency terms).
        for tok in occurrence:
            out[tok] -= (args.alpha_presence + occurrence[tok] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        generated.append(token)

        # Decay existing counts, then record the freshly sampled token.
        for tok in occurrence:
            occurrence[tok] *= 0.996
        occurrence[token] = occurrence.get(token, 0) + 1

        # Only flush when the pending ids decode without a replacement char
        # (i.e. we are not in the middle of a multi-byte sequence).
        chunk = pipeline.decode(generated[emitted_upto:])
        if '\ufffd' not in chunk:
            out_str += chunk
            yield out_str.strip()
            emitted_upto = step + 1

    if HAS_GPU:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')

    del out
    del state
    gc.collect()

    if HAS_GPU:
        torch.cuda.empty_cache()

    yield out_str.strip()
118
+
119
# Examples and gradio blocks
# Each example row is [prompt, token_count, temperature, top_p,
# presence_penalty, count_penalty]; all rows share the same sampling settings.
_EXAMPLE_ARGS = [333, 1, 0.3, 0, 1]
_EXAMPLE_PROMPTS = [
    "Assistant: Sure! Here is a very detailed plan to create flying pigs:",
    "Assistant: Sure! Here are some ideas for FTL drive:",
    generate_prompt("Tell me about ravens."),
    generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."),
    generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"),
    generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."),
    "Assistant: Here is a very detailed plan to kill all mosquitoes:",
    '''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.

User: Hello Edward. What have you been up to recently?

Edward:''',
    generate_prompt(""),
    '',
]
examples = [[p, *_EXAMPLE_ARGS] for p in _EXAMPLE_PROMPTS]
136
+
137
##########################################################################
# frp tunnel configuration.
port = 7860                  # local gradio port (logged by install_Frpc)
use_frpc = True              # whether to start the frp client at all
frpconfigfile = "7680.ini"   # frp client config
# NOTE(review): config filename says "7680" while the port is 7860 —
# possibly transposed digits; confirm against the actual .ini file.
import subprocess
142
+
143
def install_Frpc(port, frpconfigfile, use_frpc):
    """Start the bundled frpc tunnel client.

    port: forwarded port, used only in the startup log message.
    frpconfigfile: path to the frpc .ini configuration file.
    use_frpc: when falsy this function is a no-op.

    Raises CalledProcessError if ./frpc cannot be made executable.
    """
    if not use_frpc:
        return
    # Make sure the bundled binary is executable before launching it.
    subprocess.run(['chmod', '+x', './frpc'], check=True)
    # Log message (Chinese): "starting frp, port {port}".
    print(f'正在启动frp ,端口{port}')
    # Fire-and-forget: the tunnel runs alongside the gradio server.
    subprocess.Popen(['./frpc', '-c', frpconfigfile])
148
+
149
# Start the frp tunnel (no-op when use_frpc is False).
# Fixed: pass the configured `port` variable instead of the hard-coded
# string '7860' — the f-string log output is unchanged, but the call now
# stays in sync if `port` is ever edited.
install_Frpc(port, frpconfigfile, use_frpc)
150
+
151
# Gradio blocks: single "Raw Generation" tab wiring the UI controls to evaluate().
with gr.Blocks(title=title) as demo:
    gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
    with gr.Tab("Raw Generation"):
        # Fixed: the description said "3B params", but this script loads the 7B
        # checkpoint (model_file = "rwkv-5-h-world-7B").
        gr.Markdown(f"This is RWKV-5 World v2 with 7B params - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports all 100+ world languages and code. And we have [200+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(lines=2, label="Prompt", value="")
                # NOTE(review): value=100 is not on the step=200 grid — confirm intended default.
                token_count = gr.Slider(0, 10000, label="Max Tokens", step=200, value=100)
                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=1)
                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=1)
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit", variant="primary")
                    clear = gr.Button("Clear", variant="secondary")
                output = gr.Textbox(label="Output", lines=5)
        # Fixed: the module-level `examples` list was built but never shown —
        # pass it as the Dataset samples so the example rows are clickable.
        data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
        submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
        clear.click(lambda: None, [], [output])
        # Clicking an example copies its values back into the input components.
        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

# Gradio launch (local only; the frp tunnel handles external exposure).
demo.launch(share=False)