| |
| import os |
| from functools import partial |
| from typing import List, Union |
|
|
| import gradio as gr |
| from packaging import version |
| from transformers.utils import strtobool |
|
|
| import swift |
| from swift.llm import DeployArguments, EvalArguments, ExportArguments, RLHFArguments, SwiftPipeline, WebUIArguments |
| from swift.ui.llm_eval.llm_eval import LLMEval |
| from swift.ui.llm_export.llm_export import LLMExport |
| from swift.ui.llm_infer.llm_infer import LLMInfer |
| from swift.ui.llm_train.llm_train import LLMTrain |
|
|
| locale_dict = { |
| 'title': { |
| 'zh': '🚀SWIFT: 轻量级大模型训练推理框架', |
| 'en': '🚀SWIFT: Scalable lightWeight Infrastructure for Fine-Tuning and Inference' |
| }, |
| 'sub_title': { |
| 'zh': |
| '请查看 <a href=\"https://github.com/modelscope/swift/tree/main/docs/source\" target=\"_blank\">' |
| 'SWIFT 文档</a>来查看更多功能,使用SWIFT_UI_LANG=en环境变量来切换英文界面', |
| 'en': |
| 'Please check <a href=\"https://github.com/modelscope/swift/tree/main/docs/source_en\" target=\"_blank\">' |
| 'SWIFT Documentation</a> for more usages, Use SWIFT_UI_LANG=zh variable to switch to Chinese UI', |
| }, |
| 'star_beggar': { |
| 'zh': |
| '喜欢<a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>就动动手指给我们加个star吧🥺 ', |
| 'en': |
| 'If you like <a href=\"https://github.com/modelscope/swift\" target=\"_blank\">SWIFT</a>, ' |
| 'please take a few seconds to star us🥺 ' |
| }, |
| } |
|
|
|
|
| class SwiftWebUI(SwiftPipeline): |
|
|
| args_class = WebUIArguments |
| args: args_class |
|
|
| def run(self): |
| lang = os.environ.get('SWIFT_UI_LANG') or self.args.lang |
| share_env = os.environ.get('WEBUI_SHARE') |
| share = strtobool(share_env) if share_env else self.args.share |
| server = os.environ.get('WEBUI_SERVER') or self.args.server_name |
| port_env = os.environ.get('WEBUI_PORT') |
| port = int(port_env) if port_env else self.args.server_port |
| LLMTrain.set_lang(lang) |
| LLMInfer.set_lang(lang) |
| LLMExport.set_lang(lang) |
| LLMEval.set_lang(lang) |
| with gr.Blocks(title='SWIFT WebUI', theme=gr.themes.Base()) as app: |
| try: |
| _version = swift.__version__ |
| except AttributeError: |
| _version = '' |
| gr.HTML(f"<h1><center>{locale_dict['title'][lang]}({_version})</center></h1>") |
| gr.HTML(f"<h3><center>{locale_dict['sub_title'][lang]}</center></h3>") |
| with gr.Tabs(): |
| LLMTrain.build_ui(LLMTrain) |
| LLMInfer.build_ui(LLMInfer) |
| LLMExport.build_ui(LLMExport) |
| LLMEval.build_ui(LLMEval) |
|
|
| concurrent = {} |
| if version.parse(gr.__version__) < version.parse('4.0.0'): |
| concurrent = {'concurrency_count': 5} |
| app.load( |
| partial(LLMTrain.update_input_model, arg_cls=RLHFArguments), |
| inputs=[LLMTrain.element('model')], |
| outputs=[LLMTrain.element('train_record')] + list(LLMTrain.valid_elements().values())) |
| app.load( |
| partial(LLMInfer.update_input_model, arg_cls=DeployArguments, has_record=False), |
| inputs=[LLMInfer.element('model')], |
| outputs=list(LLMInfer.valid_elements().values())) |
| app.load( |
| partial(LLMExport.update_input_model, arg_cls=ExportArguments, has_record=False), |
| inputs=[LLMExport.element('model')], |
| outputs=list(LLMExport.valid_elements().values())) |
| app.load( |
| partial(LLMEval.update_input_model, arg_cls=EvalArguments, has_record=False), |
| inputs=[LLMEval.element('model')], |
| outputs=list(LLMEval.valid_elements().values())) |
| app.queue(**concurrent).launch(server_name=server, inbrowser=True, server_port=port, height=800, share=share) |
|
|
|
|
| def webui_main(args: Union[List[str], WebUIArguments, None] = None): |
| return SwiftWebUI(args).main() |
|
|