| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""Gradio web chat demo: stream a base model's answer, then the Aligner model's corrected answer."""
|
|
| import argparse |
| import os |
| from openai import OpenAI |
| import gradio as gr |
| import random |
| random.seed(42) |
|
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# System prompt sent as the first message of both the base-model request and
# the Aligner request.
SYSTEM_PROMPT = "你是一个有帮助的AI助手,能够回答用户的问题并提供帮助。"

# Connection settings for the two OpenAI-compatible servers. The literal
# defaults are the previously hard-coded values; each one can be overridden
# through an environment variable so credentials and ports do not have to be
# edited into the source.
openai_api_key = os.environ.get("ALIGNER_DEMO_API_KEY", "jiayi")

aligner_port = int(os.environ.get("ALIGNER_DEMO_ALIGNER_PORT", "8013"))
base_port = int(os.environ.get("ALIGNER_DEMO_BASE_PORT", "8011"))

# Connect through loopback rather than 0.0.0.0: 0.0.0.0 is a wildcard *bind*
# address; as a connect target it only works by OS-specific convention
# (Linux maps it to the local host) and is rejected on Windows.
_api_host = os.environ.get("ALIGNER_DEMO_HOST", "127.0.0.1")
aligner_api_base = f"http://{_api_host}:{aligner_port}/v1"
base_api_base = f"http://{_api_host}:{base_port}/v1"

# Model names sent with each chat-completions request.
# NOTE(review): empty by default, as before — confirm the serving backend
# accepts an empty model name, or set the environment variables.
aligner_model = os.environ.get("ALIGNER_DEMO_ALIGNER_MODEL", "")
base_model = os.environ.get("ALIGNER_DEMO_BASE_MODEL", "")

# Client for the Aligner (correction) model server.
aligner_client = OpenAI(
    api_key=openai_api_key,
    base_url=aligner_api_base,
)

# Client for the base (answering) model server.
base_client = OpenAI(
    api_key=openai_api_key,
    base_url=base_api_base,
)
|
|
| |
| |
| |
| |
| |
| |
# Example prompts handed to ``gr.ChatInterface(examples=...)`` in the entry point.
TEXT_EXAMPLES = [
    "介绍一下北京大学的历史",  # history of Peking University
    "解释一下什么是深度学习",  # explain deep learning
    "写一首关于春天的诗",      # a poem about spring
]
|
|
| |
| |
| |
| |
| |
|
|
def text_conversation(text: str, role: str = 'user'):
    """Wrap ``text`` as a one-element list holding a single chat message.

    Args:
        text: The message content.
        role: The chat role of the message (``'user'`` unless given).

    Returns:
        ``[{'role': role, 'content': text}]``, suitable for ``list.extend``
        onto an OpenAI-style message list.
    """
    message = dict(role=role, content=text)
    return [message]
|
|
|
|
# Markdown headers used to render the two answers inside one chat bubble.
_BASE_SECTION = "🌟 **原始回答:**\n"
_ALIGNER_SECTION = "\n**Aligner 修正中...**\n\n🌟 **修正后回答:**\n"
# Literal removed from the Aligner output before it is displayed.
_CORRECTION_MARKER = '##CORRECTION:'


def _iter_deltas(stream):
    """Yield the non-empty text pieces of a streamed chat completion."""
    for chunk in stream:
        # Some chunks can carry an empty `choices` list (e.g. a usage-only
        # chunk when `stream_options.include_usage` is set); indexing [0]
        # would raise IndexError, so skip them.
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            yield content


def _displayed_answer(bot_msg):
    """Reduce a previously rendered bot reply to the corrected answer.

    The chat history holds the whole rendered bubble (fenced base answer,
    section headers and the Aligner output). Sending that back as the
    assistant turn would leak UI decoration into the model context, so only
    the text after the Aligner header is kept. Anything else is returned
    unchanged.
    """
    if not isinstance(bot_msg, str):
        return bot_msg
    _, sep, correction = bot_msg.partition(_ALIGNER_SECTION)
    if sep and correction.strip():
        return correction
    return bot_msg


def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI-style chat messages.

    Accepts both Gradio history formats:
    * "tuples": a list of ``[user_message, bot_message]`` pairs;
    * "messages": a list of ``{"role": ..., "content": ...}`` dicts.
    Empty entries are skipped.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # Unpacking a dict as a pair would silently yield its *keys*
            # ('role', 'content'), so message dicts are handled explicitly.
            role = entry.get('role')
            content = entry.get('content')
            # NOTE(review): non-text content (e.g. file attachments) is skipped.
            if role == 'user' and isinstance(content, str) and content:
                messages.extend(text_conversation(content, 'user'))
            elif role == 'assistant' and isinstance(content, str) and content:
                messages.extend(text_conversation(_displayed_answer(content), 'assistant'))
            continue
        past_user_msg, past_bot_msg = entry
        if past_user_msg:
            messages.extend(text_conversation(past_user_msg, 'user'))
        if past_bot_msg:
            messages.extend(text_conversation(_displayed_answer(past_bot_msg), 'assistant'))
    return messages


def question_answering(message: str, history: list):
    """Answer ``message`` with the base model, then stream the Aligner correction.

    This is the ``fn`` of ``gr.ChatInterface``: every yielded string is the
    complete Markdown to show in the bot bubble so far.

    Args:
        message: The current user question.
        history: Prior turns supplied by Gradio, as ``[user, bot]`` pairs or
            ``{"role", "content"}`` message dicts.

    Yields:
        First the base answer (in a ```bash fence) as it streams. Then that
        fenced answer followed by the Aligner header and the correction
        streamed so far.
    """
    conversation = text_conversation(SYSTEM_PROMPT, 'system')
    conversation.extend(_history_to_messages(history))
    conversation.extend(text_conversation(message))

    stream = base_client.chat.completions.create(
        model=base_model,
        stream=True,
        messages=conversation,
    )

    # Show the header immediately, before the first token arrives.
    yield _BASE_SECTION
    base_answer = ""
    for piece in _iter_deltas(stream):
        base_answer += piece
        yield f"```bash\n{_BASE_SECTION}{base_answer}\n```"

    rendered_base = f"```bash\n{_BASE_SECTION}{base_answer}\n```{_ALIGNER_SECTION}"
    yield rendered_base

    # The Aligner sees only the current question and the base answer, in the
    # Question/Answer/Correction prompt layout.
    aligner_conversation = text_conversation(SYSTEM_PROMPT, 'system')
    aligner_prompt = f'##Question: {message}\n##Answer: {base_answer}\n##Correction: '
    aligner_conversation.extend(text_conversation(aligner_prompt))
    aligner_stream = aligner_client.chat.completions.create(
        model=aligner_model,
        stream=True,
        messages=aligner_conversation,
    )

    raw_correction = ""
    for piece in _iter_deltas(aligner_stream):
        raw_correction += piece
        # Strip on the accumulated text so a marker split across chunks is
        # still removed.
        yield rendered_base + raw_correction.replace(_CORRECTION_MARKER, '')
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--port", type=int, default=7860, help="Gradio服务端口") |
| parser.add_argument("--share", default='True',action="store_true", help="是否创建公共链接") |
| parser.add_argument("--api-only", default='False',action="store_true", help="只输出Python API调用示例") |
| args = parser.parse_args() |
| |
| |
| |
| |
| |
| |
| iface = gr.ChatInterface( |
| fn=question_answering, |
| title='Aligner', |
| description='网络安全 Aligner', |
| examples=TEXT_EXAMPLES, |
| theme=gr.themes.Soft( |
| text_size='lg', |
| spacing_size='lg', |
| radius_size='lg', |
| ), |
| ) |
|
|
| iface.launch(server_port=args.port, share=args.share) |