Yuhao committed
Commit 52a881a · 1 Parent(s): f7f33b5

Restructure inference and add INT4 serving
Files changed (47)
  1. LICENSE +9 -0
  2. README.md +109 -56
  3. inference/.ipynb_checkpoints/deepseek_service-checkpoint.py +0 -384
  4. inference/.ipynb_checkpoints/demo-checkpoint.py +0 -76
  5. inference/.ipynb_checkpoints/inference-checkpoint.py +0 -43
  6. inference/.ipynb_checkpoints/model_utils-checkpoint.py +0 -120
  7. inference/README.md +11 -0
  8. inference/__init__.py +1 -0
  9. inference/__pycache__/app.cpython-311.pyc +0 -0
  10. inference/__pycache__/deepseek_service.cpython-311.pyc +0 -0
  11. inference/__pycache__/model_utils.cpython-311.pyc +0 -0
  12. inference/demo.py +0 -79
  13. inference/full_precision/__init__.py +1 -0
  14. inference/full_precision/__pycache__/app.cpython-311.pyc +0 -0
  15. inference/full_precision/__pycache__/chat.cpython-311.pyc +0 -0
  16. inference/full_precision/__pycache__/deepseek_service.cpython-311.pyc +0 -0
  17. inference/full_precision/__pycache__/demo.cpython-311.pyc +0 -0
  18. inference/full_precision/__pycache__/infer.cpython-311.pyc +0 -0
  19. inference/full_precision/__pycache__/model_utils.cpython-311.pyc +0 -0
  20. inference/{app.py → full_precision/app.py} +162 -256
  21. inference/{chat.py → full_precision/chat.py} +38 -35
  22. inference/{deepseek_service.py → full_precision/deepseek_service.py} +86 -199
  23. inference/full_precision/demo.py +41 -0
  24. inference/full_precision/infer.py +54 -0
  25. inference/{model_utils.py → full_precision/model_utils.py} +103 -57
  26. inference/full_precision/run_api.sh +6 -0
  27. inference/full_precision/run_chat.sh +6 -0
  28. inference/full_precision/run_infer.sh +6 -0
  29. inference/inference.py +0 -43
  30. inference/int4_quantized/__init__.py +1 -0
  31. inference/int4_quantized/__pycache__/app.cpython-311.pyc +0 -0
  32. inference/int4_quantized/__pycache__/chat.cpython-311.pyc +0 -0
  33. inference/int4_quantized/__pycache__/infer.cpython-311.pyc +0 -0
  34. inference/int4_quantized/__pycache__/model_utils.cpython-311.pyc +0 -0
  35. inference/{.ipynb_checkpoints/app-checkpoint.py → int4_quantized/app.py} +181 -260
  36. inference/{.ipynb_checkpoints/chat-checkpoint.py → int4_quantized/chat.py} +37 -38
  37. inference/int4_quantized/infer.py +82 -0
  38. inference/int4_quantized/model_utils.py +538 -0
  39. inference/int4_quantized/run_api.sh +6 -0
  40. inference/int4_quantized/run_chat.sh +6 -0
  41. inference/int4_quantized/run_infer.sh +6 -0
  42. inference/int4_quantized/test_single.sh +6 -0
  43. inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg +0 -0
  44. inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg +0 -0
  45. inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg +0 -0
  46. inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg +0 -0
  47. requirements.txt +5 -1
LICENSE ADDED
@@ -0,0 +1,9 @@
+ SkinGPT-R1 is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license.
+
+ License summary:
+ - Attribution required
+ - Non-commercial use only
+ - Share adaptations under the same license
+
+ Full license text:
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
README.md CHANGED
@@ -10,94 +10,147 @@ tags:

  # SkinGPT-R1

- **SkinGPT-R1** is a dermatological reasoning vision Language model (VLM).

- ## ⚠️ Disclaimer

- This model is **for research and educational use only**. It is **NOT a substitute for professional medical advice, diagnosis, or treatment**.

- ## 🛠️ Environment Setup

- To ensure compatibility, we strongly recommend creating a fresh Conda environment.

- ### 1. Create Conda Environment

- Create a new environment named skingpt-r1 with Python 3.10:

  ```bash
  conda create -n skingpt-r1 python=3.10 -y
  conda activate skingpt-r1
  ```

- ### 2. Install Dependencies

  ```bash
- pip install -r requirements.txt
  ```

- ### (Optional) For faster inference on NVIDIA GPUs:

  ```bash
- pip install flash-attn --no-build-isolation
  ```

- ## 🚀 Usage

- ### Quick Start

- If you just installed the environment and want to check if it works:

- Open ***demo.py*** and Change the ***IMAGE_PATH*** variable to your image file.

  ```bash
- python demo.py
  ```

- ### Interactive Chat

- To have a multi-turn conversation (e.g., asking follow-up questions about the diagnosis) in your terminal:
  ```bash
- python chat.py --image ./test_images/lesion.jpg
  ```
- ### FastAPI Backend Deployment

- To deploy the model as a backend service (supporting image uploads and session management):

- #### Start the Server

  ```bash
- python app.py
  ```
- #### API Workflow
- Manage sessions via state_id to support multi-user history.
-
- Upload: POST /v1/upload/{state_id} — Uploads an image for the session.
-
- Chat: POST /v1/predict/{state_id} — Sends text (JSON: {"message": "..."}) and gets a response.
-
- Reset: POST /v1/reset/{state_id} — Clears session history and images.
- #### Client Example
- ```python
- import requests
-
- API_URL = "http://localhost:5900"
- STATE_ID = "patient_001"
-
- # 1. Upload Image
- with open("skin_image.jpg", "rb") as f:
-     requests.post(f"{API_URL}/v1/upload/{STATE_ID}", files={"file": f})
-
- # 2. Ask for Diagnosis
- response = requests.post(
-     f"{API_URL}/v1/predict/{STATE_ID}",
-     json={"message": "Please analyze this image."}
- )
- print("AI:", response.json()["message"])
-
- # 3. Ask Follow-up
- response = requests.post(
-     f"{API_URL}/v1/predict/{STATE_ID}",
-     json={"message": "What treatment do you recommend?"}
- )
- print("AI:", response.json()["message"])
- ```

  # SkinGPT-R1

+ **Update:** We will soon release the **SkinGPT-R1-7B** weights.
+
+ SkinGPT-R1 is a dermatological reasoning vision language model for research and education.
+
+ From **The Chinese University of Hong Kong, Shenzhen (CUHKSZ)**.
+
+ ## Disclaimer
+
+ This project is for **research and educational use only**. It is **not** a substitute for professional medical advice, diagnosis, or treatment.
+
+ ## License
+
+ This repository is released under **CC BY-NC-SA 4.0**.
+ See [LICENSE](./LICENSE) for details.
+
+ ## Structure
+
+ ```text
+ SkinGPT-R1/
+ ├── checkpoints/
+ ├── inference/
+ │   ├── full_precision/
+ │   └── int4_quantized/
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ Checkpoint paths:
+
+ - Full precision: `./checkpoints/full_precision`
+ - INT4 quantized: `./checkpoints/int4`
+
+ ## Install

  ```bash
  conda create -n skingpt-r1 python=3.10 -y
  conda activate skingpt-r1
+ pip install -r requirements.txt
  ```

+ ## Attention Backend Notes
+
+ This repo uses two attention acceleration paths:
+
+ - `flash_attention_2`: external package, optional
+ - `sdpa`: PyTorch-native scaled dot product attention
+
+ Recommended choice for this repo:
+
+ - RTX 50 series: use `sdpa`
+ - A100 / RTX 3090 / RTX 4090 / H100 and other GPUs explicitly listed by the FlashAttention project: you can try `flash_attention_2`
+
+ Practical notes:
+
+ - The current repo pins `torch==2.4.0`; SDPA is built into PyTorch in this version.
+ - FlashAttention's official README currently lists Ampere, Ada, and Hopper support for FlashAttention-2. It does not list RTX 50 / Blackwell consumer GPUs in that section, so this repo defaults to `sdpa` for that path.
+ - PyTorch 2.5 added a newer cuDNN SDPA backend for H100-class or newer GPUs, but this repo is pinned to PyTorch 2.4, so you should not assume those 2.5-specific gains here.
+
+ If you are on an RTX 5090 and `flash-attn` is unavailable or unstable in your environment, use the INT4 path in this repo, which is already configured with `attn_implementation="sdpa"`.
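
The recommendation above can be sketched as a small helper. This function, its name, and its GPU-name matching rules are illustrative assumptions for this README, not repo code: it returns `flash_attention_2` only for GPU families the FlashAttention project lists, and falls back to `sdpa` everywhere else.

```python
def pick_attn_implementation(gpu_name: str, flash_attn_installed: bool) -> str:
    # Families the FlashAttention README lists for FlashAttention-2
    # (illustrative subset; check the upstream README for the full list).
    flash_supported_families = ("A100", "RTX 3090", "RTX 4090", "H100")
    if flash_attn_installed and any(f in gpu_name for f in flash_supported_families):
        return "flash_attention_2"
    # PyTorch-native SDPA is always available on the pinned torch==2.4.0.
    return "sdpa"

print(pick_attn_implementation("NVIDIA GeForce RTX 5090", True))   # → sdpa
print(pick_attn_implementation("NVIDIA A100-SXM4-80GB", True))     # → flash_attention_2
```

The returned string can be passed as `attn_implementation=` when loading the model.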
+
+ ## Usage
+
+ ### Full Precision
+
+ Single image:

  ```bash
+ bash inference/full_precision/run_infer.sh --image ./test_images/lesion.jpg
  ```

+ Multi-turn chat:

  ```bash
+ bash inference/full_precision/run_chat.sh --image ./test_images/lesion.jpg
  ```

+ API service:

+ ```bash
+ bash inference/full_precision/run_api.sh
+ ```
+
+ Default API port: `5900`

+ ### INT4 Quantized

+ Single image:

  ```bash
+ bash inference/int4_quantized/run_infer.sh --image_path ./test_images/lesion.jpg
  ```

+ Multi-turn chat:

  ```bash
+ bash inference/int4_quantized/run_chat.sh --image ./test_images/lesion.jpg
  ```

+ API service:
+
+ ```bash
+ bash inference/int4_quantized/run_api.sh
+ ```
+
+ Default API port: `5901`
+
+ The INT4 path uses:
+
+ - `bitsandbytes` 4-bit quantization
+ - `attn_implementation="sdpa"`
+ - the adapter-aware quantized model implementation in `inference/int4_quantized/`

+ ## GPU Selection
+
+ You do not need `CUDA_VISIBLE_DEVICES=0` when the machine has only one visible GPU or the default CUDA device is acceptable.
+
+ Use it only when you want to pin the process to a specific GPU, for example on a multi-GPU server:

  ```bash
+ CUDA_VISIBLE_DEVICES=0 bash inference/int4_quantized/run_infer.sh --image_path ./test_images/lesion.jpg
  ```
+
+ The same pattern also works for:
+
+ - `inference/full_precision/run_infer.sh`
+ - `inference/full_precision/run_chat.sh`
+ - `inference/full_precision/run_api.sh`
+ - `inference/int4_quantized/run_chat.sh`
+ - `inference/int4_quantized/run_api.sh`
+
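
The same pinning can be done from Python when launching a script programmatically. The `launch_on_gpu` helper below is an illustration, not part of the repo; it only shows that the child process sees whatever `CUDA_VISIBLE_DEVICES` value the parent sets.

```python
import os
import subprocess
import sys

def launch_on_gpu(cmd, gpu_id):
    # Copy the current environment and pin the child process to one GPU.
    # Inside the child, that GPU then appears as CUDA device 0.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    return subprocess.run(cmd, env=env, capture_output=True, text=True)

# Demo with a child that just echoes the variable it received.
result = launch_on_gpu(
    [sys.executable, "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"],
    gpu_id=1,
)
print(result.stdout.strip())  # → 1
```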
+ ## API Endpoints
+
+ Both API services expose the same endpoints:
+
+ - `POST /v1/upload/{state_id}`
+ - `POST /v1/predict/{state_id}`
+ - `POST /v1/reset/{state_id}`
+ - `POST /diagnose/stream`
+ - `GET /health`
+
+ ## Which One To Use
+
+ - Use `full_precision` when you want the original model path and best fidelity.
+ - Use `int4_quantized` when GPU memory is tight or when you are in an environment where `flash-attn` is not practical.
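
The Python client example from the previous README still matches these endpoints. A minimal sketch follows; port `5900` is the full-precision default (`5901` for INT4), the `endpoint` helper and the image filename are illustrative placeholders.

```python
API_URL = "http://localhost:5900"  # use 5901 for the INT4 service
STATE_ID = "patient_001"

def endpoint(kind: str, state_id: str, base: str = API_URL) -> str:
    # kind is one of: "upload", "predict", "reset"
    return f"{base}/v1/{kind}/{state_id}"

if __name__ == "__main__":
    import requests  # third-party: pip install requests

    # 1. Upload an image for this session ("skin_image.jpg" is a placeholder)
    with open("skin_image.jpg", "rb") as f:
        requests.post(endpoint("upload", STATE_ID), files={"file": f})

    # 2. Ask for a diagnosis, then a follow-up; history is kept per state_id
    for message in ("Please analyze this image.", "What treatment do you recommend?"):
        r = requests.post(endpoint("predict", STATE_ID), json={"message": message})
        print("AI:", r.json()["message"])
```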
 
 
 
 
 
 
 
 
 
 
inference/.ipynb_checkpoints/deepseek_service-checkpoint.py DELETED
@@ -1,384 +0,0 @@
- """
- DeepSeek API Service
- Used to optimize and organize SkinGPT model output results
- """
-
- import os
- import re
- from typing import Optional
- from openai import AsyncOpenAI
-
-
- class DeepSeekService:
-     """DeepSeek API Service Class"""
-
-     def __init__(self, api_key: Optional[str] = None):
-         """
-         Initialize DeepSeek service
-
-         Parameters:
-             api_key: DeepSeek API key, reads from environment variable if not provided
-         """
-         self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
-         self.base_url = "https://api.deepseek.com"
-         self.model = "deepseek-chat"  # Using deepseek-chat model
-
-         self.client = None
-         self.is_loaded = False
-
-         print(f"DeepSeek API service initializing...")
-         print(f"API Base URL: {self.base_url}")
-
-     async def load(self):
-         """Initialize DeepSeek API client"""
-         try:
-             if not self.api_key:
-                 print("DeepSeek API key not provided")
-                 self.is_loaded = False
-                 return
-
-             # Initialize OpenAI compatible client
-             self.client = AsyncOpenAI(
-                 api_key=self.api_key,
-                 base_url=self.base_url
-             )
-
-             self.is_loaded = True
-             print("DeepSeek API service is ready!")
-
-         except Exception as e:
-             print(f"DeepSeek API service initialization failed: {e}")
-             self.is_loaded = False
-
-     async def refine_diagnosis(
-         self,
-         raw_answer: str,
-         raw_thinking: Optional[str] = None,
-         language: str = "zh"
-     ) -> dict:
-         """
-         Use DeepSeek API to optimize and organize diagnosis results
-
-         Parameters:
-             raw_answer: Original diagnosis result
-             raw_thinking: AI thinking process
-             language: Language option
-
-         Returns:
-             Dictionary containing "description", "analysis_process" and "diagnosis_result"
-         """
-
-         if not self.is_loaded or self.client is None:
-             error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
-             print("DeepSeek API not initialized, returning original result")
-             return {
-                 "success": False,
-                 "description": "",
-                 "analysis_process": raw_thinking or error_msg,
-                 "diagnosis_result": raw_answer,
-                 "original_diagnosis": raw_answer,
-                 "error": "DeepSeek API not initialized"
-             }
-
-         try:
-             # Build prompt
-             prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
-
-             # Select system prompt based on language
-             if language == "en":
-                 system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
-             else:
-                 system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
-
-             # Call DeepSeek API
-             response = await self.client.chat.completions.create(
-                 model=self.model,
-                 messages=[
-                     {"role": "system", "content": system_content},
-                     {"role": "user", "content": prompt}
-                 ],
-                 temperature=0.1,
-                 max_tokens=2048,
-                 top_p=0.8,
-             )
-
-             # Extract generated text
-             generated_text = response.choices[0].message.content
-
-             # Parse output
-             parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
-
-             return {
-                 "success": True,
-                 "description": parsed["description"],
-                 "analysis_process": parsed["analysis_process"],
-                 "diagnosis_result": parsed["diagnosis_result"],
-                 "original_diagnosis": raw_answer,
-                 "raw_refined": generated_text
-             }
-
-         except Exception as e:
-             print(f"DeepSeek API call failed: {e}")
-             error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
-             return {
-                 "success": False,
-                 "description": "",
-                 "analysis_process": raw_thinking or error_msg,
-                 "diagnosis_result": raw_answer,
-                 "original_diagnosis": raw_answer,
-                 "error": str(e)
-             }
-
-     def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
-         """
-         Build optimization prompt
-
-         Parameters:
-             raw_answer: Original diagnosis result
-             raw_thinking: AI thinking process
-             language: Language option, "zh" for Chinese, "en" for English
-
-         Returns:
-             Built prompt
-         """
-         if language == "en":
-             # English prompt - organize and polish while preserving meaning
-             thinking_text = raw_thinking if raw_thinking else "No analysis process available."
-             prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
-
- 【Requirements】
- - Preserve the original tone and expression style
- - Text 1 contains the thinking process, Text 2 contains the diagnosis result
- - Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
- - For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
- - If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
- - If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
- - DO NOT infer or add new medical information, DO NOT output any meta-commentary
- - You may adjust unreasonable statements or remove redundant content to improve clarity
-
- [Text 1]
- {thinking_text}
-
- [Text 2]
- {raw_answer}
-
- 【Output】Only output three sections, do not output anything else:
- ## Description
- (Extract all image observation content from the thinking process - include all factual descriptions of what was seen)
-
- ## Analysis Process
- (Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)
-
- ## Diagnosis Result
- (The organized diagnosis result from Text 2)
-
- 【Example】:
- ## Description
- The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
-
- ## Analysis Process
- These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
-
- ## Diagnosis Result
- Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
- """
-         else:
-             # Chinese prompt - translate to Simplified Chinese AND organize/polish
-             thinking_text = raw_thinking if raw_thinking else "No analysis process available."
-             prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
-
- 【要求】
- - 保留原文的语气和表达方式
- - 文本1是思考过程,文本2是诊断结果
- - 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
- - 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不应包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
- - 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
- - 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
- - 【文本1】和【文本2】需要翻译成简体中文
- - 禁止推断或添加新的医学信息,禁止输出任何元评论
- - 可以调整不合理的语句或去除冗余内容以提高清晰度
-
-
- 【文本1】
- {thinking_text}
-
- 【文本2】
- {raw_answer}
-
- 【输出】只输出三个部分,不要输出其他任何内容:
- ## 图像描述
- (从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)
-
- ## 分析过程
- (提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)
-
- ## 诊断结果
- (整理后的诊断结果)
-
- 【样例】:
- ## 图像描述
- 图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
-
- ## 分析过程
- 这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
-
- ## 诊断结果
- 可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
- """
-
-         return prompt
-
-     def _parse_refined_output(
-         self,
-         generated_text: str,
-         raw_answer: str,
-         raw_thinking: Optional[str] = None,
-         language: str = "zh"
-     ) -> dict:
-         """
-         Parse DeepSeek generated output
-
-         Parameters:
-             generated_text: DeepSeek generated text
-             raw_answer: Original diagnosis (as fallback)
-             raw_thinking: Original thinking process (as fallback)
-             language: Language option
-
-         Returns:
-             Dictionary containing description, analysis_process and diagnosis_result
-         """
-         description = ""
-         analysis_process = None
-         diagnosis_result = None
-
-         if language == "en":
-             # English patterns
-             desc_match = re.search(
-                 r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
-                 generated_text,
-                 re.IGNORECASE
-             )
-             analysis_match = re.search(
-                 r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
-                 generated_text,
-                 re.IGNORECASE
-             )
-             result_match = re.search(
-                 r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
-                 generated_text,
-                 re.IGNORECASE
-             )
-
-             desc_header = "## Description"
-             analysis_header = "## Analysis Process"
-             result_header = "## Diagnosis Result"
-         else:
-             # Chinese patterns
-             desc_match = re.search(
-                 r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
-                 generated_text
-             )
-             analysis_match = re.search(
-                 r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
-                 generated_text
-             )
-             result_match = re.search(
-                 r'##\s*诊断结果\s*\n([\s\S]*?)$',
-                 generated_text
-             )
-
-             desc_header = "## 图像描述"
-             analysis_header = "## 分析过程"
-             result_header = "## 诊断结果"
-
-         # Extract description
-         if desc_match:
-             description = desc_match.group(1).strip()
-             print(f"Successfully parsed description")
-         else:
-             print(f"Description parsing failed")
-             description = ""
-
-         # Extract analysis process
-         if analysis_match:
-             analysis_process = analysis_match.group(1).strip()
-             print(f"Successfully parsed analysis process")
-         else:
-             print(f"Analysis process parsing failed, trying other methods")
-             # Try to extract from generated text
-             result_pos = generated_text.find(result_header)
-             if result_pos > 0:
-                 # Get content before diagnosis result
-                 analysis_process = generated_text[:result_pos].strip()
-                 # Remove possible headers
-                 for header in [desc_header, analysis_header]:
-                     header_escaped = re.escape(header)
-                     analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
-             else:
-                 # If no format at all, try to get first half
-                 mid_point = len(generated_text) // 2
-                 analysis_process = generated_text[:mid_point].strip()
-
-         # If still empty, use original content (final fallback)
-         if not analysis_process and raw_thinking:
-             print(f"Using original raw_thinking as fallback")
-             analysis_process = raw_thinking
-
-         # Extract diagnosis result
-         if result_match:
-             diagnosis_result = result_match.group(1).strip()
-             print(f"Successfully parsed diagnosis result")
-         else:
-             print(f"Diagnosis result parsing failed, trying other methods")
-             # Try to extract from generated text
-             result_pos = generated_text.find(result_header)
-             if result_pos > 0:
-                 diagnosis_result = generated_text[result_pos:].strip()
-                 # Remove possible header
-                 result_header_escaped = re.escape(result_header)
-                 diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
-             else:
-                 # If no format at all, get second half
-                 mid_point = len(generated_text) // 2
-                 diagnosis_result = generated_text[mid_point:].strip()
-
-         # If still empty, use original content (final fallback)
-         if not diagnosis_result:
-             print(f"Using original raw_answer as fallback")
-             diagnosis_result = raw_answer
-
-         return {
-             "description": description,
-             "analysis_process": analysis_process,
-             "diagnosis_result": diagnosis_result
-         }
-
-
- # Global DeepSeek service instance (lazy loading)
- _deepseek_service: Optional[DeepSeekService] = None
-
-
- async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
-     """
-     Get DeepSeek service instance (singleton pattern)
-
-     Parameters:
-         api_key: Optional API key to use
-
-     Returns:
-         DeepSeekService instance, or None if API initialization fails
-     """
-     global _deepseek_service
-
-     if _deepseek_service is None:
-         try:
-             _deepseek_service = DeepSeekService(api_key=api_key)
-             await _deepseek_service.load()
-             if not _deepseek_service.is_loaded:
-                 print("DeepSeek API service initialization failed, will use fallback mode")
-                 return _deepseek_service  # Return instance but marked as not loaded
-         except Exception as e:
-             print(f"DeepSeek service initialization failed: {e}")
-             return None
-
-     return _deepseek_service
inference/.ipynb_checkpoints/demo-checkpoint.py DELETED
@@ -1,76 +0,0 @@
- import torch
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
- from qwen_vl_utils import process_vision_info
- from PIL import Image
-
- # === Configuration ===
- MODEL_PATH = "../checkpoint"
- IMAGE_PATH = "test_image.jpg"  # Please replace with your actual image path
- PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
-
- def main():
-     print(f"Loading model from {MODEL_PATH}...")
-
-     # 1. Load Model
-     try:
-         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-             MODEL_PATH,
-             torch_dtype=torch.bfloat16,
-             device_map="auto",
-             trust_remote_code=True
-         )
-         processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         return
-
-     # 2. Check Image
-     import os
-     if not os.path.exists(IMAGE_PATH):
-         print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
-         # Create a dummy image for code demonstration purposes if needed, or just return
-         return
-
-     # 3. Prepare Inputs
-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {"type": "image", "image": IMAGE_PATH},
-                 {"type": "text", "text": PROMPT},
-             ],
-         }
-     ]
-
-     print("Processing...")
-     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     image_inputs, video_inputs = process_vision_info(messages)
-
-     inputs = processor(
-         text=[text],
-         images=image_inputs,
-         videos=video_inputs,
-         padding=True,
-         return_tensors="pt",
-     ).to(model.device)
-
-     # 4. Generate
-     with torch.no_grad():
-         generated_ids = model.generate(
-             **inputs,
-             max_new_tokens=1024,
-             temperature=0.7,
-             top_p=0.9
-         )
-
-     # 5. Decode
-     output_text = processor.batch_decode(
-         generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
-     )
-
-     print("\n=== Diagnosis Result ===")
-     print(output_text[0])
-     print("========================")
-
- if __name__ == "__main__":
-     main()
inference/.ipynb_checkpoints/inference-checkpoint.py DELETED
@@ -1,43 +0,0 @@
-
- import argparse
- from model_utils import SkinGPTModel
- import os
-
- def main():
-     parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
-     parser.add_argument("--image", type=str, required=True, help="Path to the image")
-     parser.add_argument("--model_path", type=str, default="../checkpoint")
-     parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
-     args = parser.parse_args()
-
-     if not os.path.exists(args.image):
-         print(f"Error: Image not found at {args.image}")
-         return
-
-     # 1. Load the model (reuses model_utils,
-     # so the transformers loading code is not duplicated here)
-     bot = SkinGPTModel(args.model_path)
-
-     # 2. Build a single-turn message
-     system_prompt = "You are a professional AI dermatology assistant."
-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {"type": "image", "image": args.image},
-                 {"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
-             ]
-         }
-     ]
-
-     # 3. Run inference
-     print(f"\nAnalyzing {args.image}...")
-     response = bot.generate_response(messages)
-
-     print("-" * 40)
-     print("Result:")
-     print(response)
-     print("-" * 40)
-
- if __name__ == "__main__":
-     main()
inference/.ipynb_checkpoints/model_utils-checkpoint.py DELETED
@@ -1,120 +0,0 @@
1
- # model_utils.py
2
- import torch
3
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
4
- from qwen_vl_utils import process_vision_info
5
- from PIL import Image
6
- import os
7
- from threading import Thread
8
-
9
- class SkinGPTModel:
10
- def __init__(self, model_path, device=None):
11
- self.model_path = model_path
12
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
- print(f"Loading model from {model_path} on {self.device}...")
14
-
15
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
16
- model_path,
17
- torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
18
- attn_implementation="flash_attention_2" if self.device == "cuda" else None,
19
- device_map="auto" if self.device != "mps" else None,
20
- trust_remote_code=True
21
- )
22
-
23
- if self.device == "mps":
24
- self.model = self.model.to(self.device)
25
-
26
- self.processor = AutoProcessor.from_pretrained(
27
- model_path,
28
- trust_remote_code=True,
29
- min_pixels=256*28*28,
30
- max_pixels=1280*28*28
31
- )
32
- print("Model loaded successfully.")
33
-
34
- def generate_response(self, messages, max_new_tokens=1024, temperature=0.7):
35
- """
36
- Generate a reply from a multi-turn message history.
37
- messages format:
38
- [
39
- {'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
40
- {'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
41
- ]
42
- """
43
- # Apply the chat template to the text
44
- text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45
-
46
- # Preprocess the vision inputs
47
- image_inputs, video_inputs = process_vision_info(messages)
48
-
49
- inputs = self.processor(
50
- text=[text],
51
- images=image_inputs,
52
- videos=video_inputs,
53
- padding=True,
54
- return_tensors="pt",
55
- ).to(self.model.device)
56
-
57
- with torch.no_grad():
58
- generated_ids = self.model.generate(
59
- **inputs,
60
- max_new_tokens=max_new_tokens,
61
- temperature=temperature,
62
- top_p=0.9,
63
- do_sample=True
64
- )
65
-
66
- # Decode the output (strip the input tokens)
67
- generated_ids_trimmed = [
68
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
69
- ]
70
- output_text = self.processor.batch_decode(
71
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
72
- )
73
-
74
- return output_text[0]
75
-
76
- def generate_response_stream(self, messages, max_new_tokens=2048, temperature=0.7):
77
- """
78
- Stream the generated response.
79
- Returns a generator that yields text chunks as they are produced.
80
- """
81
- # Apply the chat template to the text
82
- text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
83
-
84
- # Preprocess the vision inputs
85
- image_inputs, video_inputs = process_vision_info(messages)
86
-
87
- inputs = self.processor(
88
- text=[text],
89
- images=image_inputs,
90
- videos=video_inputs,
91
- padding=True,
92
- return_tensors="pt",
93
- ).to(self.model.device)
94
-
95
- # Create a TextIteratorStreamer for streaming output
96
- streamer = TextIteratorStreamer(
97
- self.processor.tokenizer,
98
- skip_prompt=True,
99
- skip_special_tokens=True
100
- )
101
-
102
- # Assemble the generation kwargs
103
- generation_kwargs = {
104
- **inputs,
105
- "max_new_tokens": max_new_tokens,
106
- "temperature": temperature,
107
- "top_p": 0.9,
108
- "do_sample": True,
109
- "streamer": streamer,
110
- }
111
-
112
- # Run generation in a separate thread
113
- thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
114
- thread.start()
115
-
116
- # Yield the generated text chunks one by one
117
- for text_chunk in streamer:
118
- yield text_chunk
119
-
120
- thread.join()
 
 
inference/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inference
2
+
3
+ Two runtime tracks are provided:
4
+
5
+ - `full_precision/`: single-image inference, multi-turn chat, and FastAPI service
6
+ - `int4_quantized/`: the same entrypoints (inference, chat, FastAPI service) for the INT4-quantized checkpoint
7
+
8
+ Checkpoint paths:
9
+
10
+ - `./checkpoints/full_precision`
11
+ - `./checkpoints/int4`
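
Both tracks feed the model Qwen2.5-VL-style message lists. A minimal sketch of that message shape (the helper name mirrors `build_single_turn_messages` in `full_precision/model_utils.py`; the signature here is an illustrative assumption, not the exact one in the repo):

```python
def build_single_turn_messages(image_path, user_text, system_prompt=None):
    """Build a one-turn Qwen2.5-VL message list with an image and a text prompt."""
    text = f"{system_prompt}\n\n{user_text}" if system_prompt else user_text
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": text},
            ],
        }
    ]

# Hypothetical usage; the resulting list is what SkinGPTModel.generate_response expects.
messages = build_single_turn_messages(
    "lesion.jpg",
    "Please analyze this image.",
    system_prompt="You are a professional AI dermatology assistant.",
)
```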
inference/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Inference entrypoints for SkinGPT-R1."""
inference/__pycache__/app.cpython-311.pyc DELETED
Binary file (17.8 kB)
 
inference/__pycache__/deepseek_service.cpython-311.pyc DELETED
Binary file (18.3 kB)
 
inference/__pycache__/model_utils.cpython-311.pyc DELETED
Binary file (5.39 kB)
 
inference/demo.py DELETED
@@ -1,79 +0,0 @@
1
- import torch
2
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
3
- from qwen_vl_utils import process_vision_info
4
- from PIL import Image
5
-
6
- # === Configuration ===
7
- MODEL_PATH = "../checkpoint"
8
- IMAGE_PATH = "test_image.jpg" # Please replace with your actual image path
9
- PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
10
-
11
- def main():
12
- print(f"Loading model from {MODEL_PATH}...")
13
-
14
- # 1. Load Model
15
- try:
16
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
17
- MODEL_PATH,
18
- torch_dtype=torch.bfloat16,
19
- device_map="auto",
20
- trust_remote_code=True
21
- )
22
- processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
23
- except Exception as e:
24
- print(f"Error loading model: {e}")
25
- return
26
-
27
- # 2. Check Image
28
- import os
29
- if not os.path.exists(IMAGE_PATH):
30
- print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
31
- # Create a dummy image for code demonstration purposes if needed, or just return
32
- return
33
-
34
- # 3. Prepare Inputs
35
- messages = [
36
- {
37
- "role": "user",
38
- "content": [
39
- {"type": "image", "image": IMAGE_PATH},
40
- {"type": "text", "text": PROMPT},
41
- ],
42
- }
43
- ]
44
-
45
- print("Processing...")
46
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
47
- image_inputs, video_inputs = process_vision_info(messages)
48
-
49
- inputs = processor(
50
- text=[text],
51
- images=image_inputs,
52
- videos=video_inputs,
53
- padding=True,
54
- return_tensors="pt",
55
- ).to(model.device)
56
-
57
- # 4. Generate
58
- with torch.no_grad():
59
- generated_ids = model.generate(
60
- **inputs,
61
- max_new_tokens=1024,
62
- temperature=0.7,
63
- repetition_penalty=1.2,
64
- no_repeat_ngram_size=3,
65
- top_p=0.9,
66
- do_sample=True
67
- )
68
-
69
- # 5. Decode
70
- output_text = processor.batch_decode(
71
- generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
72
- )
73
-
74
- print("\n=== Diagnosis Result ===")
75
- print(output_text[0])
76
- print("========================")
77
-
78
- if __name__ == "__main__":
79
- main()
 
 
inference/full_precision/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Full-precision inference package for SkinGPT-R1."""
inference/full_precision/__pycache__/app.cpython-311.pyc ADDED
Binary file (17.4 kB). View file
 
inference/full_precision/__pycache__/chat.cpython-311.pyc ADDED
Binary file (3.58 kB). View file
 
inference/full_precision/__pycache__/deepseek_service.cpython-311.pyc ADDED
Binary file (12.7 kB). View file
 
inference/full_precision/__pycache__/demo.cpython-311.pyc ADDED
Binary file (1.8 kB). View file
 
inference/full_precision/__pycache__/infer.cpython-311.pyc ADDED
Binary file (2.63 kB). View file
 
inference/full_precision/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (7.37 kB). View file
 
inference/{app.py → full_precision/app.py} RENAMED
@@ -1,133 +1,81 @@
1
- # app.py
2
- import uvicorn
 
 
3
  import os
4
  import shutil
5
  import uuid
6
- import json
7
- import re
8
- import asyncio
9
- from typing import Optional
10
- from io import BytesIO
11
  from contextlib import asynccontextmanager
12
- from PIL import Image
13
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
 
 
 
 
 
 
 
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.responses import StreamingResponse
16
- from fastapi.concurrency import run_in_threadpool
17
- from model_utils import SkinGPTModel
18
- from deepseek_service import get_deepseek_service, DeepSeekService
19
 
20
- # === Configuration ===
21
- MODEL_PATH = "../checkpoint"
22
- TEMP_DIR = "./temp_uploads"
23
- os.makedirs(TEMP_DIR, exist_ok=True)
 
 
24
 
25
- # DeepSeek API Key
26
- DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
 
 
27
 
28
- # Global DeepSeek service instance
29
  deepseek_service: Optional[DeepSeekService] = None
30
 
31
- @asynccontextmanager
32
- async def lifespan(app: FastAPI):
33
- """应用生命周期管理"""
34
- # 启动时初始化 DeepSeek 服务
35
- await init_deepseek()
36
- yield
37
- print("\nShutting down service...")
38
 
39
- app = FastAPI(
40
- title="SkinGPT-R1 皮肤诊断系统",
41
- description="智能皮肤诊断助手",
42
- version="1.0.0",
43
- lifespan=lifespan
44
- )
45
 
46
- # CORS configuration: allow frontend access
47
- app.add_middleware(
48
- CORSMiddleware,
49
- allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
50
- allow_credentials=True,
51
- allow_methods=["*"],
52
- allow_headers=["*"],
53
- )
54
 
55
- # Global state containers
56
- # chat_states: conversation history (list of messages for Qwen)
57
- # pending_images: uploaded image paths not yet sent to the LLM (state ID -> image path)
58
- chat_states = {}
59
- pending_images = {}
60
 
61
- def parse_diagnosis_result(raw_text: str) -> dict:
62
- """
63
- Parse the <think> and <answer> tags from the diagnosis output.
64
-
65
- Args:
66
- - raw_text: raw diagnosis text
67
-
68
- Returns:
69
- - dict with thinking, answer, and raw fields
70
- """
71
- import re
72
-
73
- # Try to match complete tags
74
- think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
75
- answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
76
-
77
- thinking = None
78
- answer = None
79
-
80
- # Handle the <think> tag
81
- if think_match:
82
- thinking = think_match.group(1).strip()
83
- else:
84
- # Try to match an unclosed <think> tag (truncated output)
85
- unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
86
  if unclosed_think:
87
  thinking = unclosed_think.group(1).strip()
88
-
89
- # Handle the <answer> tag
90
- if answer_match:
91
- answer = answer_match.group(1).strip()
92
- else:
93
- # Try to match an unclosed <answer> tag
94
- unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
95
  if unclosed_answer:
96
  answer = unclosed_answer.group(1).strip()
97
-
98
- # If no answer was found, fall back to the cleaned raw text
99
  if not answer:
100
- # Remove all tags and their contents
101
- cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
102
- cleaned = re.sub(r'<think>[\s\S]*', '', cleaned) # strip any unclosed <think>
103
- cleaned = re.sub(r'</?answer>', '', cleaned) # strip <answer> tags
104
- cleaned = cleaned.strip()
105
- answer = cleaned if cleaned else raw_text
106
-
107
- # Strip any leftover tags
108
- if answer:
109
- answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
110
- if thinking:
111
- thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
112
-
113
- # 处理 "Final Answer:" 格式,提取其后的内容
114
  if answer:
115
- final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
 
116
  if final_answer_match:
117
  answer = final_answer_match.group(1).strip()
118
-
119
- return {
120
- "thinking": thinking if thinking else None,
121
- "answer": answer,
122
- "raw": raw_text
123
- }
124
 
125
  print("Initializing Model Service...")
126
- # Load the model globally
127
  gpt_model = SkinGPTModel(MODEL_PATH)
128
  print("Service Ready.")
129
 
130
- # Initialize the DeepSeek service (async)
131
  async def init_deepseek():
132
  global deepseek_service
133
  print("\nInitializing DeepSeek service...")
@@ -137,120 +85,116 @@ async def init_deepseek():
137
  else:
138
  print("DeepSeek service not available, will return raw results")
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  @app.post("/v1/upload/{state_id}")
141
  async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
142
- """
143
- Accept an image upload.
144
- Save the image to a local temp directory and mark this state_id as having a pending image.
145
- """
146
  try:
147
- # 1. Save the image to a local temp file
148
  file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
149
  unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
150
- file_path = os.path.join(TEMP_DIR, unique_name)
151
-
152
- with open(file_path, "wb") as buffer:
153
  shutil.copyfileobj(file.file, buffer)
154
-
155
- # 2. Record the image path for the next predict call
156
- # For multi-image mode this could become a list; for now a single image is overwritten
157
- pending_images[state_id] = file_path
158
-
159
- # 3. Initialize the conversation state (for a new session)
160
  if state_id not in chat_states:
161
  chat_states[state_id] = []
162
-
163
- return {"message": "Image uploaded successfully", "path": file_path}
164
-
165
- except Exception as e:
166
- raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
167
 
168
  @app.post("/v1/predict/{state_id}")
169
  async def v1_predict(request: Request, state_id: str):
170
- """
171
- Accept text and run inference.
172
- If a pending image exists, combine it with the text into a multimodal message.
173
- """
174
  try:
175
  data = await request.json()
176
- except:
177
- raise HTTPException(status_code=400, detail="Invalid JSON")
178
-
179
  user_message = data.get("message", "")
180
  if not user_message:
181
  raise HTTPException(status_code=400, detail="Missing 'message' field")
182
 
183
- # Get or initialize the history
184
  history = chat_states.get(state_id, [])
185
-
186
- # Build this turn's user content
187
  current_content = []
188
-
189
- # 1. Check for a freshly uploaded image
190
  if state_id in pending_images:
191
- img_path = pending_images.pop(state_id) # take and remove
192
  current_content.append({"type": "image", "image": img_path})
193
-
194
- # On the first turn, prepend the system prompt
195
  if not history:
196
- system_prompt = "You are a professional AI dermatology assistant. "
197
- user_message = f"{system_prompt}\n\n{user_message}"
198
 
199
- # 2. Append the text
200
  current_content.append({"type": "text", "text": user_message})
201
-
202
- # 3. Update the history
203
  history.append({"role": "user", "content": current_content})
204
  chat_states[state_id] = history
205
 
206
- # 4. Run inference (in a thread pool to avoid blocking the event loop)
207
  try:
208
- response_text = await run_in_threadpool(
209
- gpt_model.generate_response,
210
- messages=history
211
- )
212
- except Exception as e:
213
- # Roll back the history (drop the failed user turn)
214
  chat_states[state_id].pop()
215
- raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
216
 
217
- # 5. Append the reply to the history
218
  history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
219
  chat_states[state_id] = history
220
-
221
  return {"message": response_text}
222
 
 
223
  @app.post("/v1/reset/{state_id}")
224
  async def reset_chat(state_id: str):
225
- """清除会话状态"""
226
  if state_id in chat_states:
227
  del chat_states[state_id]
228
  if state_id in pending_images:
229
- # Optional: delete the temp file
230
  try:
231
- os.remove(pending_images[state_id])
232
- except:
233
  pass
234
  del pending_images[state_id]
235
  return {"message": "Chat history reset"}
236
 
 
237
  @app.get("/")
238
  async def root():
239
- """根路径"""
240
  return {
241
- "name": "SkinGPT-R1 皮肤诊断系统",
242
- "version": "1.0.0",
243
  "status": "running",
244
- "description": "智能皮肤诊断助手"
245
  }
246
 
 
247
  @app.get("/health")
248
  async def health_check():
249
- """健康检查"""
250
- return {
251
- "status": "healthy",
252
- "model_loaded": True
253
- }
254
 
255
  @app.post("/diagnose/stream")
256
  async def diagnose_stream(
@@ -258,126 +202,89 @@ async def diagnose_stream(
258
  text: str = Form(...),
259
  language: str = Form("zh"),
260
  ):
261
- """
262
- SSE streaming diagnosis endpoint (for the frontend).
263
- Supports image upload and text input with true streaming responses.
264
- Uses the DeepSeek API to refine the output format.
265
- """
266
- from queue import Queue, Empty
267
- from threading import Thread
268
-
269
  language = language if language in ("zh", "en") else "zh"
270
-
271
- # Handle the image
272
  pil_image = None
273
- temp_image_path = None
274
-
275
  if image:
276
  contents = await image.read()
277
  pil_image = Image.open(BytesIO(contents)).convert("RGB")
278
-
279
- # Queue for inter-thread communication
280
  result_queue = Queue()
281
- # Holds the full response and the parsed result
282
  generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
283
-
284
  def run_generation():
285
- """在后台线程中运行流式生成"""
286
  full_response = []
287
-
288
  try:
289
- # Build the messages
290
  messages = []
291
  current_content = []
292
-
293
- # Add the system prompt
294
- system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"
295
-
296
- # If there is an image, save it to a temp file
 
297
  if pil_image:
298
- generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
299
- pil_image.save(generation_result["temp_image_path"])
300
- current_content.append({"type": "image", "image": generation_result["temp_image_path"]})
301
-
302
- # Append the text
303
- prompt = f"{system_prompt}\n\n{text}"
304
- current_content.append({"type": "text", "text": prompt})
305
  messages.append({"role": "user", "content": current_content})
306
-
307
- # Streaming generation: each chunk is queued immediately
308
  for chunk in gpt_model.generate_response_stream(
309
  messages=messages,
310
  max_new_tokens=2048,
311
- temperature=0.7
312
  ):
313
  full_response.append(chunk)
314
  result_queue.put(("delta", chunk))
315
-
316
- # Parse the result
317
  response_text = "".join(full_response)
318
- parsed = parse_diagnosis_result(response_text)
319
  generation_result["full_response"] = full_response
320
- generation_result["parsed"] = parsed
321
-
322
- # Mark generation as complete
323
  result_queue.put(("generation_done", None))
324
-
325
- except Exception as e:
326
- result_queue.put(("error", str(e)))
327
-
328
  async def event_generator():
329
- """异步生成SSE事件"""
330
- # 在后台线程启动生成(非阻塞)
331
  gen_thread = Thread(target=run_generation)
332
  gen_thread.start()
333
-
334
  loop = asyncio.get_event_loop()
335
-
336
- # Read from the queue and stream the content
337
  while True:
338
  try:
339
- # Non-blocking fetch
340
  msg_type, data = await loop.run_in_executor(
341
- None,
342
- lambda: result_queue.get(timeout=0.1)
343
  )
344
-
345
  if msg_type == "generation_done":
346
- # Streaming finished; prepare the final result
347
  break
348
- elif msg_type == "delta":
349
- yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
350
- yield f"data: {yield_chunk}\n\n"
351
  elif msg_type == "error":
352
  yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
353
  gen_thread.join()
354
  return
355
-
356
  except Empty:
357
- # Queue temporarily empty; keep waiting
358
  await asyncio.sleep(0.01)
359
- continue
360
-
361
  gen_thread.join()
362
-
363
- # Fetch the parsed result
364
  parsed = generation_result["parsed"]
365
  if not parsed:
366
- yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
367
  return
368
-
369
  raw_thinking = parsed["thinking"]
370
  raw_answer = parsed["answer"]
371
-
372
- # Refine the result with DeepSeek
373
  refined_by_deepseek = False
374
  description = None
375
  thinking = raw_thinking
376
  answer = raw_answer
377
-
378
  if deepseek_service and deepseek_service.is_loaded:
379
  try:
380
- print(f"Calling DeepSeek to refine diagnosis (language={language})...")
381
  refined = await deepseek_service.refine_diagnosis(
382
  raw_answer=raw_answer,
383
  raw_thinking=raw_thinking,
@@ -388,36 +295,35 @@ async def diagnose_stream(
388
  thinking = refined["analysis_process"]
389
  answer = refined["diagnosis_result"]
390
  refined_by_deepseek = True
391
- print(f"DeepSeek refinement completed successfully")
392
- except Exception as e:
393
- print(f"DeepSeek refinement failed, using original: {e}")
394
  else:
395
  print("DeepSeek service not available, using raw results")
396
-
397
- success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
398
-
399
- # Keep the response format consistent with the reference project
400
  final_payload = {
401
- "description": description, # 图片描述(从 thinking 中提取)
402
- "thinking": thinking, # 分析过程(DeepSeek 优化后)
403
- "answer": answer, # 诊断结果(DeepSeek 优化后)
404
- "raw": parsed["raw"], # 原始响应
405
- "refined_by_deepseek": refined_by_deepseek, # 是否被 DeepSeek 优化
406
  "success": True,
407
- "message": success_msg
408
  }
409
- yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
410
- yield f"data: {yield_final}\n\n"
411
-
412
- # Clean up the temp image
413
  temp_path = generation_result.get("temp_image_path")
414
- if temp_path and os.path.exists(temp_path):
415
  try:
416
- os.remove(temp_path)
417
- except:
418
  pass
419
-
420
  return StreamingResponse(event_generator(), media_type="text/event-stream")
421
 
422
- if __name__ == '__main__':
423
- uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
  import os
6
  import shutil
7
  import uuid
 
 
 
 
 
8
  from contextlib import asynccontextmanager
9
+ from io import BytesIO
10
+ from pathlib import Path
11
+ from queue import Empty, Queue
12
+ from threading import Thread
13
+ from typing import Optional
14
+
15
+ import uvicorn
16
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
17
+ from fastapi.concurrency import run_in_threadpool
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from fastapi.responses import StreamingResponse
20
+ from PIL import Image
 
 
21
 
22
+ try:
23
+ from .deepseek_service import DeepSeekService, get_deepseek_service
24
+ from .model_utils import DEFAULT_MODEL_PATH, SkinGPTModel, resolve_model_path
25
+ except ImportError:
26
+ from deepseek_service import DeepSeekService, get_deepseek_service
27
+ from model_utils import DEFAULT_MODEL_PATH, SkinGPTModel, resolve_model_path
28
 
29
+ MODEL_PATH = resolve_model_path(DEFAULT_MODEL_PATH)
30
+ TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
31
+ TEMP_DIR.mkdir(parents=True, exist_ok=True)
32
+ DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
33
 
 
34
  deepseek_service: Optional[DeepSeekService] = None
35
 
 
 
 
 
 
 
 
36
 
37
+ def parse_diagnosis_result(raw_text: str) -> dict:
38
+ import re
 
 
 
 
39
 
40
+ think_match = re.search(r"<think>([\s\S]*?)</think>", raw_text)
41
+ answer_match = re.search(r"<answer>([\s\S]*?)</answer>", raw_text)
 
 
 
 
 
 
42
 
43
+ thinking = think_match.group(1).strip() if think_match else None
44
+ answer = answer_match.group(1).strip() if answer_match else None
 
 
 
45
 
46
+ if not thinking:
47
+ unclosed_think = re.search(r"<think>([\s\S]*?)(?=<answer>|$)", raw_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if unclosed_think:
49
  thinking = unclosed_think.group(1).strip()
50
+
51
+ if not answer:
52
+ unclosed_answer = re.search(r"<answer>([\s\S]*?)$", raw_text)
 
 
 
 
53
  if unclosed_answer:
54
  answer = unclosed_answer.group(1).strip()
55
+
 
56
  if not answer:
57
+ cleaned = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
58
+ cleaned = re.sub(r"<think>[\s\S]*", "", cleaned)
59
+ cleaned = re.sub(r"</?answer>", "", cleaned)
60
+ answer = cleaned.strip() or raw_text
61
+
 
 
 
 
 
 
 
 
 
62
  if answer:
63
+ answer = re.sub(r"</?think>|</?answer>", "", answer).strip()
64
+ final_answer_match = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
65
  if final_answer_match:
66
  answer = final_answer_match.group(1).strip()
67
+
68
+ if thinking:
69
+ thinking = re.sub(r"</?think>|</?answer>", "", thinking).strip()
70
+
71
+ return {"thinking": thinking or None, "answer": answer, "raw": raw_text}
72
+
73
 
74
  print("Initializing Model Service...")
 
75
  gpt_model = SkinGPTModel(MODEL_PATH)
76
  print("Service Ready.")
77
 
78
+
79
  async def init_deepseek():
80
  global deepseek_service
81
  print("\nInitializing DeepSeek service...")
 
85
  else:
86
  print("DeepSeek service not available, will return raw results")
87
 
88
+
89
+ @asynccontextmanager
90
+ async def lifespan(app: FastAPI):
91
+ await init_deepseek()
92
+ yield
93
+ print("\nShutting down service...")
94
+
95
+
96
+ app = FastAPI(
97
+ title="SkinGPT-R1 Full Precision API",
98
+ description="Full-precision dermatology assistant backend",
99
+ version="1.1.0",
100
+ lifespan=lifespan,
101
+ )
102
+
103
+ app.add_middleware(
104
+ CORSMiddleware,
105
+ allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
106
+ allow_credentials=True,
107
+ allow_methods=["*"],
108
+ allow_headers=["*"],
109
+ )
110
+
111
+ chat_states = {}
112
+ pending_images = {}
113
+
114
+
115
  @app.post("/v1/upload/{state_id}")
116
  async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
117
+ del survey
 
 
 
118
  try:
 
119
  file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
120
  unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
121
+ file_path = TEMP_DIR / unique_name
122
+
123
+ with file_path.open("wb") as buffer:
124
  shutil.copyfileobj(file.file, buffer)
125
+
126
+ pending_images[state_id] = str(file_path)
127
+
 
 
 
128
  if state_id not in chat_states:
129
  chat_states[state_id] = []
130
+
131
+ return {"message": "Image uploaded successfully", "path": str(file_path)}
132
+ except Exception as exc:
133
+ raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc
134
+
135
 
136
  @app.post("/v1/predict/{state_id}")
137
  async def v1_predict(request: Request, state_id: str):
 
 
 
 
138
  try:
139
  data = await request.json()
140
+ except Exception as exc:
141
+ raise HTTPException(status_code=400, detail="Invalid JSON") from exc
142
+
143
  user_message = data.get("message", "")
144
  if not user_message:
145
  raise HTTPException(status_code=400, detail="Missing 'message' field")
146
 
 
147
  history = chat_states.get(state_id, [])
 
 
148
  current_content = []
149
+
 
150
  if state_id in pending_images:
151
+ img_path = pending_images.pop(state_id)
152
  current_content.append({"type": "image", "image": img_path})
 
 
153
  if not history:
154
+ user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"
 
155
 
 
156
  current_content.append({"type": "text", "text": user_message})
 
 
157
  history.append({"role": "user", "content": current_content})
158
  chat_states[state_id] = history
159
 
 
160
  try:
161
+ response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
162
+ except Exception as exc:
 
 
 
 
163
  chat_states[state_id].pop()
164
+ raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
165
 
 
166
  history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
167
  chat_states[state_id] = history
 
168
  return {"message": response_text}
169
 
170
+
171
  @app.post("/v1/reset/{state_id}")
172
  async def reset_chat(state_id: str):
 
173
  if state_id in chat_states:
174
  del chat_states[state_id]
175
  if state_id in pending_images:
 
176
  try:
177
+ Path(pending_images[state_id]).unlink(missing_ok=True)
178
+ except Exception:
179
  pass
180
  del pending_images[state_id]
181
  return {"message": "Chat history reset"}
182
 
183
+
184
  @app.get("/")
185
  async def root():
 
186
  return {
187
+ "name": "SkinGPT-R1 Full Precision API",
188
+ "version": "1.1.0",
189
  "status": "running",
190
+ "description": "Full-precision dermatology assistant",
191
  }
192
 
193
+
194
  @app.get("/health")
195
  async def health_check():
196
+ return {"status": "healthy", "model_loaded": True}
197
+
 
 
 
198
 
199
  @app.post("/diagnose/stream")
200
  async def diagnose_stream(
 
202
  text: str = Form(...),
203
  language: str = Form("zh"),
204
  ):
 
 
 
 
 
 
 
 
205
  language = language if language in ("zh", "en") else "zh"
 
 
206
  pil_image = None
207
+
 
208
  if image:
209
  contents = await image.read()
210
  pil_image = Image.open(BytesIO(contents)).convert("RGB")
211
+
 
212
  result_queue = Queue()
 
213
  generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
214
+
215
  def run_generation():
 
216
  full_response = []
 
217
  try:
 
218
  messages = []
219
  current_content = []
220
+ system_prompt = (
221
+ "You are a professional AI dermatology assistant."
222
+ if language == "en"
223
+ else "你是一个专业的AI皮肤科助手。"
224
+ )
225
+
226
  if pil_image:
227
+ temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
228
+ pil_image.save(temp_image_path)
229
+ generation_result["temp_image_path"] = str(temp_image_path)
230
+ current_content.append({"type": "image", "image": str(temp_image_path)})
231
+
232
+ current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
 
233
  messages.append({"role": "user", "content": current_content})
234
+
 
235
  for chunk in gpt_model.generate_response_stream(
236
  messages=messages,
237
  max_new_tokens=2048,
238
+ temperature=0.7,
239
  ):
240
  full_response.append(chunk)
241
  result_queue.put(("delta", chunk))
242
+
 
243
  response_text = "".join(full_response)
 
244
  generation_result["full_response"] = full_response
245
+ generation_result["parsed"] = parse_diagnosis_result(response_text)
 
 
246
  result_queue.put(("generation_done", None))
247
+ except Exception as exc:
248
+ result_queue.put(("error", str(exc)))
249
+
 
250
  async def event_generator():
 
 
251
  gen_thread = Thread(target=run_generation)
252
  gen_thread.start()
253
+
254
  loop = asyncio.get_event_loop()
 
 
255
  while True:
256
  try:
 
257
  msg_type, data = await loop.run_in_executor(
258
+ None,
259
+ lambda: result_queue.get(timeout=0.1),
260
  )
 
261
  if msg_type == "generation_done":
 
262
  break
263
+ if msg_type == "delta":
264
+ yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
 
265
  elif msg_type == "error":
266
  yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
267
  gen_thread.join()
268
  return
 
269
  except Empty:
 
270
  await asyncio.sleep(0.01)
271
+
 
272
  gen_thread.join()
273
+
 
274
  parsed = generation_result["parsed"]
275
  if not parsed:
276
+ yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
277
  return
278
+
279
  raw_thinking = parsed["thinking"]
280
  raw_answer = parsed["answer"]
 
 
281
  refined_by_deepseek = False
282
  description = None
283
  thinking = raw_thinking
284
  answer = raw_answer
285
+
286
  if deepseek_service and deepseek_service.is_loaded:
287
  try:
 
288
  refined = await deepseek_service.refine_diagnosis(
289
  raw_answer=raw_answer,
290
  raw_thinking=raw_thinking,
 
295
  thinking = refined["analysis_process"]
296
  answer = refined["diagnosis_result"]
297
  refined_by_deepseek = True
298
+ except Exception as exc:
299
+ print(f"DeepSeek refinement failed, using original: {exc}")
 
300
  else:
301
  print("DeepSeek service not available, using raw results")
302
+
 
 
 
303
  final_payload = {
304
+ "description": description,
305
+ "thinking": thinking,
306
+ "answer": answer,
307
+ "raw": parsed["raw"],
308
+ "refined_by_deepseek": refined_by_deepseek,
309
  "success": True,
310
+ "message": "Diagnosis completed" if language == "en" else "诊断完成",
311
  }
312
+ yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"
313
+
 
 
314
  temp_path = generation_result.get("temp_image_path")
315
+ if temp_path:
316
  try:
317
+ Path(temp_path).unlink(missing_ok=True)
318
+ except Exception:
319
  pass
320
+
321
  return StreamingResponse(event_generator(), media_type="text/event-stream")
322
 
323
+
324
+ def main() -> None:
325
+ uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
326
+
327
+
328
+ if __name__ == "__main__":
329
+ main()
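
The `/diagnose/stream` endpoint in the new `app.py` emits Server-Sent Events where each `data:` line carries a JSON payload of type `delta`, `final`, or `error`. A minimal client-side sketch for collecting those events (the sample payloads below are illustrative, not captured output):

```python
import json

def parse_sse_events(body: str) -> list:
    """Extract the JSON payload from each 'data: ...' line of an SSE body."""
    events = []
    for line in body.splitlines():
        if line.startswith("data: "):
            events.append(json.loads(line[len("data: "):]))
    return events

# Illustrative sample of the stream shape produced by /diagnose/stream.
sample = (
    'data: {"type": "delta", "text": "Erythematous plaque"}\n\n'
    'data: {"type": "final", "result": {"success": true, "answer": "..."}}\n\n'
)
events = parse_sse_events(sample)
streamed_text = "".join(e["text"] for e in events if e["type"] == "delta")
```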
inference/{chat.py → full_precision/chat.py} RENAMED
@@ -1,48 +1,53 @@
1
- # chat.py
 
2
  import argparse
3
- import os
4
- from model_utils import SkinGPTModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def main():
7
- parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
8
- parser.add_argument("--model_path", type=str, default="../checkpoint")
 
9
  parser.add_argument("--image", type=str, required=True, help="Path to initial image")
10
- args = parser.parse_args()
 
11
 
12
- # Initialize the model
13
- bot = SkinGPTModel(args.model_path)
14
 
15
- # Initialize the conversation history
16
- # System prompt
17
- system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."
18
-
19
- # Build the first message, which contains the image
20
- if not os.path.exists(args.image):
21
  print(f"Error: Image {args.image} not found.")
22
  return
23
 
24
- history = [
25
- {
26
- "role": "user",
27
- "content": [
28
- {"type": "image", "image": args.image},
29
- {"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."}
30
- ]
31
- }
32
- ]
33
 
34
  print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
35
  print(f"Image loaded: {args.image}")
36
-
37
- # 获取第一轮诊断
38
  print("\nModel is thinking...", end="", flush=True)
39
- response = bot.generate_response(history)
40
  print(f"\rAssistant: {response}\n")
41
-
42
- # 将助手的回复加入历史
43
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
44
 
45
- # Enter the multi-turn chat loop
46
  while True:
47
  try:
48
  user_input = input("User: ")
@@ -51,18 +56,16 @@ def main():
51
  if not user_input.strip():
52
  continue
53
 
54
- # Append the user's new question
55
  history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
56
 
57
  print("Model is thinking...", end="", flush=True)
58
- response = bot.generate_response(history)
59
  print(f"\rAssistant: {response}\n")
60
 
61
- # Append the assistant's new reply
62
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
63
-
64
  except KeyboardInterrupt:
65
  break
66
 
 
67
  if __name__ == "__main__":
68
- main()
 
1
+ from __future__ import annotations
2
+
3
  import argparse
4
+ from pathlib import Path
5
+
6
+ try:
7
+ from .model_utils import (
8
+ DEFAULT_MODEL_PATH,
9
+ SkinGPTModel,
10
+ build_single_turn_messages,
11
+ resolve_model_path,
12
+ )
13
+ except ImportError:
14
+ from model_utils import (
15
+ DEFAULT_MODEL_PATH,
16
+ SkinGPTModel,
17
+ build_single_turn_messages,
18
+ resolve_model_path,
19
+ )
20
 
21
+
22
+ def build_parser() -> argparse.ArgumentParser:
23
+ parser = argparse.ArgumentParser(description="SkinGPT-R1 full-precision multi-turn chat")
24
+ parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
25
  parser.add_argument("--image", type=str, required=True, help="Path to initial image")
26
+ return parser
27
+
28
 
29
+ def main() -> None:
30
+ args = build_parser().parse_args()
31
 
32
+ if not Path(args.image).exists():
 
 
 
 
 
33
  print(f"Error: Image {args.image} not found.")
34
  return
35
 
36
+ model = SkinGPTModel(resolve_model_path(args.model_path))
37
+ history = build_single_turn_messages(
38
+ args.image,
39
+ "Please analyze this image.",
40
+ system_prompt="You are a professional AI dermatology assistant. Analyze the skin condition carefully.",
41
+ )
 
 
 
42
 
43
  print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
44
  print(f"Image loaded: {args.image}")
45
+
 
46
  print("\nModel is thinking...", end="", flush=True)
47
+ response = model.generate_response(history)
48
  print(f"\rAssistant: {response}\n")
 
 
49
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
50
 
 
51
  while True:
52
  try:
53
  user_input = input("User: ")
 
56
  if not user_input.strip():
57
  continue
58
 
 
59
  history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
60
 
61
  print("Model is thinking...", end="", flush=True)
62
+ response = model.generate_response(history)
63
  print(f"\rAssistant: {response}\n")
64
 
 
65
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
 
66
  except KeyboardInterrupt:
67
  break
68
 
69
+
70
  if __name__ == "__main__":
71
+ main()
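For reference, the conversation state the chat loop maintains is a plain list of role/content dicts in the Qwen2.5-VL message format, matching what `build_single_turn_messages` returns plus the appended turns; the image path and texts below are placeholders:

```python
# Initial single-turn message: one user entry with an image part and a text part.
history = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "lesion.jpg"},
            {"type": "text", "text": "Please analyze this image."},
        ],
    }
]

# Each turn appends one assistant reply, then the next user question.
history.append({"role": "assistant", "content": [{"type": "text", "text": "Findings: ..."}]})
history.append({"role": "user", "content": [{"type": "text", "text": "Is it contagious?"}]})

print([m["role"] for m in history])  # ['user', 'assistant', 'user']
```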
inference/{deepseek_service.py → full_precision/deepseek_service.py} RENAMED
@@ -1,75 +1,51 @@
1
- """
2
- DeepSeek API Service
3
- Used to optimize and organize SkinGPT model output results
4
- """
5
 
6
  import os
7
  import re
8
  from typing import Optional
 
9
  from openai import AsyncOpenAI
10
 
11
 
12
  class DeepSeekService:
13
- """DeepSeek API Service Class"""
14
-
15
  def __init__(self, api_key: Optional[str] = None):
16
- """
17
- Initialize DeepSeek service
18
-
19
- Parameters:
20
- api_key: DeepSeek API key, reads from environment variable if not provided
21
- """
22
  self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
23
  self.base_url = "https://api.deepseek.com"
24
- self.model = "deepseek-chat" # Using deepseek-chat model
25
-
26
  self.client = None
27
  self.is_loaded = False
28
-
29
- print(f"DeepSeek API service initializing...")
30
  print(f"API Base URL: {self.base_url}")
31
-
32
  async def load(self):
33
- """Initialize DeepSeek API client"""
34
  try:
35
  if not self.api_key:
36
  print("DeepSeek API key not provided")
37
  self.is_loaded = False
38
  return
39
-
40
- # Initialize OpenAI compatible client
41
- self.client = AsyncOpenAI(
42
- api_key=self.api_key,
43
- base_url=self.base_url
44
- )
45
-
46
  self.is_loaded = True
47
  print("DeepSeek API service is ready!")
48
-
49
- except Exception as e:
50
- print(f"DeepSeek API service initialization failed: {e}")
51
  self.is_loaded = False
52
-
53
  async def refine_diagnosis(
54
- self,
55
  raw_answer: str,
56
  raw_thinking: Optional[str] = None,
57
- language: str = "zh"
58
  ) -> dict:
59
- """
60
- Use DeepSeek API to optimize and organize diagnosis results
61
-
62
- Parameters:
63
- raw_answer: Original diagnosis result
64
- raw_thinking: AI thinking process
65
- language: Language option
66
-
67
- Returns:
68
- Dictionary containing "description", "analysis_process" and "diagnosis_result"
69
- """
70
-
71
  if not self.is_loaded or self.client is None:
72
- error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
 
 
 
 
73
  print("DeepSeek API not initialized, returning original result")
74
  return {
75
  "success": False,
@@ -77,74 +53,67 @@ class DeepSeekService:
77
  "analysis_process": raw_thinking or error_msg,
78
  "diagnosis_result": raw_answer,
79
  "original_diagnosis": raw_answer,
80
- "error": "DeepSeek API not initialized"
81
  }
82
-
83
  try:
84
- # Build prompt
85
  prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
86
-
87
- # Select system prompt based on language
88
- if language == "en":
89
- system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
90
- else:
91
- system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
92
-
93
- # Call DeepSeek API
 
94
  response = await self.client.chat.completions.create(
95
  model=self.model,
96
  messages=[
97
  {"role": "system", "content": system_content},
98
- {"role": "user", "content": prompt}
99
  ],
100
  temperature=0.1,
101
  max_tokens=2048,
102
  top_p=0.8,
103
  )
104
-
105
- # Extract generated text
106
  generated_text = response.choices[0].message.content
107
-
108
- # Parse output
109
  parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
110
-
111
  return {
112
  "success": True,
113
  "description": parsed["description"],
114
  "analysis_process": parsed["analysis_process"],
115
  "diagnosis_result": parsed["diagnosis_result"],
116
  "original_diagnosis": raw_answer,
117
- "raw_refined": generated_text
118
  }
119
-
120
- except Exception as e:
121
- print(f"DeepSeek API call failed: {e}")
122
- error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
 
 
 
123
  return {
124
  "success": False,
125
  "description": "",
126
  "analysis_process": raw_thinking or error_msg,
127
  "diagnosis_result": raw_answer,
128
  "original_diagnosis": raw_answer,
129
- "error": str(e)
130
  }
131
-
132
- def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
133
- """
134
- Build optimization prompt
135
-
136
- Parameters:
137
- raw_answer: Original diagnosis result
138
- raw_thinking: AI thinking process
139
- language: Language option, "zh" for Chinese, "en" for English
140
-
141
- Returns:
142
- Built prompt
143
- """
144
  if language == "en":
145
- # English prompt - organize and polish while preserving meaning
146
- thinking_text = raw_thinking if raw_thinking else "No analysis process available."
147
- prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
148
 
149
  【Requirements】
150
  - Preserve the original tone and expression style
@@ -171,21 +140,9 @@ class DeepSeekService:
171
 
172
  ## Diagnosis Result
173
  (The organized diagnosis result from Text 2)
174
-
175
- 【Example】:
176
- ## Description
177
- The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
178
-
179
- ## Analysis Process
180
- These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
181
-
182
- ## Diagnosis Result
183
- Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
184
  """
185
- else:
186
- # Chinese prompt - translate to Simplified Chinese AND organize/polish
187
- thinking_text = raw_thinking if raw_thinking else "No analysis process available."
188
- prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
189
 
190
  【要求】
191
  - 保留原文的语气和表达方式
@@ -198,7 +155,6 @@ Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition,
198
  - 禁止推断或添加新的医学信息,禁止输出任何元评论
199
  - 可以调整不合理的语句或去除冗余内容以提高清晰度
200
 
201
-
202
  【文本1】
203
  {thinking_text}
204
 
@@ -214,171 +170,102 @@ Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition,
214
 
215
  ## 诊断结果
216
  (整理后的诊断结果)
217
-
218
- 【样例】:
219
- ## 图像描述
220
- 图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
221
-
222
- ## 分析过程
223
- 这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
224
-
225
- ## 诊断结果
226
- 可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
227
  """
228
-
229
- return prompt
230
-
231
  def _parse_refined_output(
232
- self,
233
- generated_text: str,
234
  raw_answer: str,
235
  raw_thinking: Optional[str] = None,
236
- language: str = "zh"
237
  ) -> dict:
238
- """
239
- Parse DeepSeek generated output
240
-
241
- Parameters:
242
- generated_text: DeepSeek generated text
243
- raw_answer: Original diagnosis (as fallback)
244
- raw_thinking: Original thinking process (as fallback)
245
- language: Language option
246
-
247
- Returns:
248
- Dictionary containing description, analysis_process and diagnosis_result
249
- """
250
  description = ""
251
  analysis_process = None
252
  diagnosis_result = None
253
-
254
  if language == "en":
255
- # English patterns
256
  desc_match = re.search(
257
- r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
258
  generated_text,
259
- re.IGNORECASE
260
  )
261
  analysis_match = re.search(
262
- r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
263
  generated_text,
264
- re.IGNORECASE
265
  )
266
  result_match = re.search(
267
- r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
268
  generated_text,
269
- re.IGNORECASE
270
  )
271
-
272
  desc_header = "## Description"
273
  analysis_header = "## Analysis Process"
274
  result_header = "## Diagnosis Result"
275
  else:
276
- # Chinese patterns
277
- desc_match = re.search(
278
- r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
279
- generated_text
280
- )
281
- analysis_match = re.search(
282
- r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
283
- generated_text
284
- )
285
- result_match = re.search(
286
- r'##\s*诊断结果\s*\n([\s\S]*?)$',
287
- generated_text
288
- )
289
-
290
  desc_header = "## 图像描述"
291
  analysis_header = "## 分析过程"
292
  result_header = "## 诊断结果"
293
-
294
- # Extract description
295
  if desc_match:
296
  description = desc_match.group(1).strip()
297
- print(f"Successfully parsed description")
298
  else:
299
- print(f"Description parsing failed")
300
  description = ""
301
-
302
- # Extract analysis process
303
  if analysis_match:
304
  analysis_process = analysis_match.group(1).strip()
305
- print(f"Successfully parsed analysis process")
306
  else:
307
- print(f"Analysis process parsing failed, trying other methods")
308
- # Try to extract from generated text
309
  result_pos = generated_text.find(result_header)
310
  if result_pos > 0:
311
- # Get content before diagnosis result
312
  analysis_process = generated_text[:result_pos].strip()
313
- # Remove possible headers
314
  for header in [desc_header, analysis_header]:
315
- header_escaped = re.escape(header)
316
- analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
317
  else:
318
- # If no format at all, try to get first half
319
- mid_point = len(generated_text) // 2
320
- analysis_process = generated_text[:mid_point].strip()
321
-
322
- # If still empty, use original content (final fallback)
323
  if not analysis_process and raw_thinking:
324
- print(f"Using original raw_thinking as fallback")
325
  analysis_process = raw_thinking
326
-
327
- # Extract diagnosis result
328
  if result_match:
329
  diagnosis_result = result_match.group(1).strip()
330
- print(f"Successfully parsed diagnosis result")
331
  else:
332
- print(f"Diagnosis result parsing failed, trying other methods")
333
- # Try to extract from generated text
334
  result_pos = generated_text.find(result_header)
335
  if result_pos > 0:
336
  diagnosis_result = generated_text[result_pos:].strip()
337
- # Remove possible header
338
- result_header_escaped = re.escape(result_header)
339
- diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
 
 
340
  else:
341
- # If no format at all, get second half
342
- mid_point = len(generated_text) // 2
343
- diagnosis_result = generated_text[mid_point:].strip()
344
-
345
- # If still empty, use original content (final fallback)
346
  if not diagnosis_result:
347
- print(f"Using original raw_answer as fallback")
348
  diagnosis_result = raw_answer
349
-
350
  return {
351
  "description": description,
352
  "analysis_process": analysis_process,
353
- "diagnosis_result": diagnosis_result
354
  }
355
 
356
 
357
- # Global DeepSeek service instance (lazy loading)
358
  _deepseek_service: Optional[DeepSeekService] = None
359
 
360
 
361
  async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
362
- """
363
- Get DeepSeek service instance (singleton pattern)
364
-
365
- Parameters:
366
- api_key: Optional API key to use
367
-
368
- Returns:
369
- DeepSeekService instance, or None if API initialization fails
370
- """
371
  global _deepseek_service
372
-
373
  if _deepseek_service is None:
374
  try:
375
  _deepseek_service = DeepSeekService(api_key=api_key)
376
  await _deepseek_service.load()
377
  if not _deepseek_service.is_loaded:
378
  print("DeepSeek API service initialization failed, will use fallback mode")
379
- return _deepseek_service # Return instance but marked as not loaded
380
- except Exception as e:
381
- print(f"DeepSeek service initialization failed: {e}")
382
  return None
383
-
384
  return _deepseek_service
 
1
+ from __future__ import annotations
 
 
 
2
 
3
  import os
4
  import re
5
  from typing import Optional
6
+
7
  from openai import AsyncOpenAI
8
 
9
 
10
  class DeepSeekService:
11
+ """OpenAI-compatible DeepSeek refinement service."""
12
+
13
  def __init__(self, api_key: Optional[str] = None):
 
 
 
 
 
 
14
  self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
15
  self.base_url = "https://api.deepseek.com"
16
+ self.model = "deepseek-chat"
 
17
  self.client = None
18
  self.is_loaded = False
19
+
20
+ print("DeepSeek API service initializing...")
21
  print(f"API Base URL: {self.base_url}")
22
+
23
  async def load(self):
 
24
  try:
25
  if not self.api_key:
26
  print("DeepSeek API key not provided")
27
  self.is_loaded = False
28
  return
29
+
30
+ self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
 
 
 
 
31
  self.is_loaded = True
32
  print("DeepSeek API service is ready!")
33
+ except Exception as exc:
34
+ print(f"DeepSeek API service initialization failed: {exc}")
 
35
  self.is_loaded = False
36
+
37
  async def refine_diagnosis(
38
+ self,
39
  raw_answer: str,
40
  raw_thinking: Optional[str] = None,
41
+ language: str = "zh",
42
  ) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
43
  if not self.is_loaded or self.client is None:
44
+ error_msg = (
45
+ "API not initialized, cannot generate analysis"
46
+ if language == "en"
47
+ else "API未初始化,无法生成分析过程"
48
+ )
49
  print("DeepSeek API not initialized, returning original result")
50
  return {
51
  "success": False,
 
53
  "analysis_process": raw_thinking or error_msg,
54
  "diagnosis_result": raw_answer,
55
  "original_diagnosis": raw_answer,
56
+ "error": "DeepSeek API not initialized",
57
  }
58
+
59
  try:
 
60
  prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
61
+ system_content = (
62
+ "You are a professional medical text editor. Your task is to polish and organize "
63
+ "medical diagnostic text to make it flow smoothly while preserving the original "
64
+ "meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, "
65
+ "or thoughts. Just follow the format exactly."
66
+ if language == "en"
67
+ else "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
68
+ )
69
+
70
  response = await self.client.chat.completions.create(
71
  model=self.model,
72
  messages=[
73
  {"role": "system", "content": system_content},
74
+ {"role": "user", "content": prompt},
75
  ],
76
  temperature=0.1,
77
  max_tokens=2048,
78
  top_p=0.8,
79
  )
80
+
 
81
  generated_text = response.choices[0].message.content
 
 
82
  parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
83
+
84
  return {
85
  "success": True,
86
  "description": parsed["description"],
87
  "analysis_process": parsed["analysis_process"],
88
  "diagnosis_result": parsed["diagnosis_result"],
89
  "original_diagnosis": raw_answer,
90
+ "raw_refined": generated_text,
91
  }
92
+ except Exception as exc:
93
+ print(f"DeepSeek API call failed: {exc}")
94
+ error_msg = (
95
+ "API call failed, cannot generate analysis"
96
+ if language == "en"
97
+ else "API调用失败,无法生成分析过程"
98
+ )
99
  return {
100
  "success": False,
101
  "description": "",
102
  "analysis_process": raw_thinking or error_msg,
103
  "diagnosis_result": raw_answer,
104
  "original_diagnosis": raw_answer,
105
+ "error": str(exc),
106
  }
107
+
108
+ def _build_refine_prompt(
109
+ self,
110
+ raw_answer: str,
111
+ raw_thinking: Optional[str] = None,
112
+ language: str = "zh",
113
+ ) -> str:
114
+ thinking_text = raw_thinking if raw_thinking else "No analysis process available."
 
 
 
 
 
115
  if language == "en":
116
+ return f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
 
 
117
 
118
  【Requirements】
119
  - Preserve the original tone and expression style
 
140
 
141
  ## Diagnosis Result
142
  (The organized diagnosis result from Text 2)
 
 
 
 
 
 
 
 
 
 
143
  """
144
+
145
+ return f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
 
 
146
 
147
  【要求】
148
  - 保留原文的语气和表达方式
 
155
  - 禁止推断或添加新的医学信息,禁止输出任何元评论
156
  - 可以调整不合理的语句或去除冗余内容以提高清晰度
157
 
 
158
  【文本1】
159
  {thinking_text}
160
 
 
170
 
171
  ## 诊断结果
172
  (整理后的诊断结果)
 
 
 
 
 
 
 
 
 
 
173
  """
174
+
 
 
175
  def _parse_refined_output(
176
+ self,
177
+ generated_text: str,
178
  raw_answer: str,
179
  raw_thinking: Optional[str] = None,
180
+ language: str = "zh",
181
  ) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
182
  description = ""
183
  analysis_process = None
184
  diagnosis_result = None
185
+
186
  if language == "en":
 
187
  desc_match = re.search(
188
+ r"##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)",
189
  generated_text,
190
+ re.IGNORECASE,
191
  )
192
  analysis_match = re.search(
193
+ r"##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)",
194
  generated_text,
195
+ re.IGNORECASE,
196
  )
197
  result_match = re.search(
198
+ r"##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$",
199
  generated_text,
200
+ re.IGNORECASE,
201
  )
 
202
  desc_header = "## Description"
203
  analysis_header = "## Analysis Process"
204
  result_header = "## Diagnosis Result"
205
  else:
206
+ desc_match = re.search(r"##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)", generated_text)
207
+ analysis_match = re.search(r"##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)", generated_text)
208
+ result_match = re.search(r"##\s*诊断结果\s*\n([\s\S]*?)$", generated_text)
 
 
 
 
 
 
 
 
 
 
 
209
  desc_header = "## 图像描述"
210
  analysis_header = "## 分析过程"
211
  result_header = "## 诊断结果"
212
+
 
213
  if desc_match:
214
  description = desc_match.group(1).strip()
 
215
  else:
 
216
  description = ""
217
+
 
218
  if analysis_match:
219
  analysis_process = analysis_match.group(1).strip()
 
220
  else:
 
 
221
  result_pos = generated_text.find(result_header)
222
  if result_pos > 0:
 
223
  analysis_process = generated_text[:result_pos].strip()
 
224
  for header in [desc_header, analysis_header]:
225
+ analysis_process = re.sub(f"{re.escape(header)}\\s*\\n?", "", analysis_process).strip()
 
226
  else:
227
+ analysis_process = generated_text[: len(generated_text) // 2].strip()
 
 
 
 
228
  if not analysis_process and raw_thinking:
 
229
  analysis_process = raw_thinking
230
+
 
231
  if result_match:
232
  diagnosis_result = result_match.group(1).strip()
 
233
  else:
 
 
234
  result_pos = generated_text.find(result_header)
235
  if result_pos > 0:
236
  diagnosis_result = generated_text[result_pos:].strip()
237
+ diagnosis_result = re.sub(
238
+ f"^{re.escape(result_header)}\\s*\\n?",
239
+ "",
240
+ diagnosis_result,
241
+ ).strip()
242
  else:
243
+ diagnosis_result = generated_text[len(generated_text) // 2 :].strip()
 
 
 
 
244
  if not diagnosis_result:
 
245
  diagnosis_result = raw_answer
246
+
247
  return {
248
  "description": description,
249
  "analysis_process": analysis_process,
250
+ "diagnosis_result": diagnosis_result,
251
  }
252
 
253
 
 
254
  _deepseek_service: Optional[DeepSeekService] = None
255
 
256
 
257
  async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
 
 
 
 
 
 
 
 
 
258
  global _deepseek_service
259
+
260
  if _deepseek_service is None:
261
  try:
262
  _deepseek_service = DeepSeekService(api_key=api_key)
263
  await _deepseek_service.load()
264
  if not _deepseek_service.is_loaded:
265
  print("DeepSeek API service initialization failed, will use fallback mode")
266
+ return _deepseek_service
267
+ except Exception as exc:
268
+ print(f"DeepSeek service initialization failed: {exc}")
269
  return None
270
+
271
  return _deepseek_service
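The three English section regexes in `_parse_refined_output` above can be exercised in isolation; `split_sections` below is a standalone restatement of those patterns for testing, not a function in the service:

```python
import re

SAMPLE = """## Description
Red inflamed patches with pustules.

## Analysis Process
Consistent with acne vulgaris.

## Diagnosis Result
Possible diagnosis: acne."""


def split_sections(text: str) -> dict:
    """Split refined output into the three sections using the same lazy-match + lookahead patterns."""
    patterns = {
        "description": r"##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)",
        "analysis_process": r"##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)",
        "diagnosis_result": r"##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$",
    }
    return {
        key: (m.group(1).strip() if (m := re.search(pat, text, re.IGNORECASE)) else "")
        for key, pat in patterns.items()
    }


sections = split_sections(SAMPLE)
print(sections["diagnosis_result"])  # Possible diagnosis: acne.
```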
inference/full_precision/demo.py ADDED
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ try:
6
+ from .model_utils import (
7
+ DEFAULT_MODEL_PATH,
8
+ SkinGPTModel,
9
+ build_single_turn_messages,
10
+ resolve_model_path,
11
+ )
12
+ except ImportError:
13
+ from model_utils import (
14
+ DEFAULT_MODEL_PATH,
15
+ SkinGPTModel,
16
+ build_single_turn_messages,
17
+ resolve_model_path,
18
+ )
19
+
20
+ IMAGE_PATH = "test_image.jpg"
21
+ PROMPT = "Please analyze this skin image and provide a diagnosis."
22
+
23
+
24
+ def main() -> None:
25
+ if not Path(IMAGE_PATH).exists():
26
+ print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
27
+ return
28
+
29
+ model = SkinGPTModel(resolve_model_path(DEFAULT_MODEL_PATH))
30
+ messages = build_single_turn_messages(IMAGE_PATH, PROMPT)
31
+
32
+ print("Processing...")
33
+ output_text = model.generate_response(messages)
34
+
35
+ print("\n=== Diagnosis Result ===")
36
+ print(output_text)
37
+ print("========================")
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
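The `resolve_model_path` helper that `demo.py` imports tries the path as given, then CWD-relative, then repo-root-relative. Its candidate ordering can be restated as a pure function; `repo_root` and `cwd` are explicit parameters here only so the sketch is deterministic (the real helper derives them from `__file__` and the current directory):

```python
from pathlib import Path


def candidate_paths(model_path: str, repo_root: Path, cwd: Path) -> list:
    """Candidate order mirroring resolve_model_path: as given, CWD-relative, repo-root-relative."""
    raw = Path(model_path).expanduser()
    candidates = [raw]
    if not raw.is_absolute():
        candidates += [cwd / raw, repo_root / raw]
        # If the path already starts with the repo directory name, also try stripping it.
        if raw.parts and raw.parts[0] == repo_root.name:
            candidates.append(repo_root.joinpath(*raw.parts[1:]))
    return candidates


print(candidate_paths("./checkpoints/full_precision", Path("/srv/SkinGPT-R1"), Path("/home/u")))
```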
inference/full_precision/infer.py ADDED
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ try:
7
+ from .model_utils import (
8
+ DEFAULT_MODEL_PATH,
9
+ SkinGPTModel,
10
+ build_single_turn_messages,
11
+ resolve_model_path,
12
+ )
13
+ except ImportError:
14
+ from model_utils import (
15
+ DEFAULT_MODEL_PATH,
16
+ SkinGPTModel,
17
+ build_single_turn_messages,
18
+ resolve_model_path,
19
+ )
20
+
21
+
22
+ def build_parser() -> argparse.ArgumentParser:
23
+ parser = argparse.ArgumentParser(description="SkinGPT-R1 full-precision single inference")
24
+ parser.add_argument("--image", type=str, required=True, help="Path to the image")
25
+ parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
26
+ parser.add_argument(
27
+ "--prompt",
28
+ type=str,
29
+ default="Please analyze this skin image and provide a diagnosis.",
30
+ )
31
+ return parser
32
+
33
+
34
+ def main() -> None:
35
+ args = build_parser().parse_args()
36
+
37
+ if not Path(args.image).exists():
38
+ print(f"Error: Image not found at {args.image}")
39
+ return
40
+
41
+ model = SkinGPTModel(resolve_model_path(args.model_path))
42
+ messages = build_single_turn_messages(args.image, args.prompt)
43
+
44
+ print(f"\nAnalyzing {args.image}...")
45
+ response = model.generate_response(messages)
46
+
47
+ print("-" * 40)
48
+ print("Result:")
49
+ print(response)
50
+ print("-" * 40)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ main()
inference/{model_utils.py → full_precision/model_utils.py} RENAMED
@@ -1,51 +1,96 @@
1
- # model_utils.py
 
 
 
 
 
2
  import torch
3
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
4
  from qwen_vl_utils import process_vision_info
5
- from PIL import Image
6
- import os
7
- from threading import Thread
8
 
9
  class SkinGPTModel:
10
- def __init__(self, model_path, device=None):
11
- self.model_path = model_path
 
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
- print(f"Loading model from {model_path} on {self.device}...")
14
-
15
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
16
- model_path,
17
  torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
18
  attn_implementation="flash_attention_2" if self.device == "cuda" else None,
19
  device_map="auto" if self.device != "mps" else None,
20
- trust_remote_code=True
21
  )
22
-
23
  if self.device == "mps":
24
  self.model = self.model.to(self.device)
25
 
26
  self.processor = AutoProcessor.from_pretrained(
27
- model_path,
28
- trust_remote_code=True,
29
- min_pixels=256*28*28,
30
- max_pixels=1280*28*28
31
  )
32
  print("Model loaded successfully.")
33
 
34
- def generate_response(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
35
- """
36
- Process a multi-turn conversation history and generate a reply
37
- messages format:
38
- [
39
- {'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
40
- {'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
41
- ]
42
- """
43
- # Preprocess the text template
44
- text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45
-
46
- # Preprocess the vision inputs
47
  image_inputs, video_inputs = process_vision_info(messages)
48
-
49
  inputs = self.processor(
50
  text=[text],
51
  images=image_inputs,
@@ -62,30 +107,35 @@ class SkinGPTModel:
62
  repetition_penalty=repetition_penalty,
63
  no_repeat_ngram_size=no_repeat_ngram_size,
64
  top_p=0.9,
65
- do_sample=True
66
  )
67
 
68
- # Decode the output (strip the input tokens)
69
  generated_ids_trimmed = [
70
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
71
  ]
72
  output_text = self.processor.batch_decode(
73
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
 
 
74
  )
75
-
76
  return output_text[0]
77
-
78
- def generate_response_stream(self, messages, max_new_tokens=1024, temperature=0.7, repetition_penalty=1.2, no_repeat_ngram_size=3):
79
- """
80
- Generate a response as a stream
81
- Returns a generator that yields generated text chunks one at a time
82
- """
83
- # Preprocess the text template
84
- text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
85
-
86
- # Preprocess the vision inputs
 
 
 
 
87
  image_inputs, video_inputs = process_vision_info(messages)
88
-
89
  inputs = self.processor(
90
  text=[text],
91
  images=image_inputs,
@@ -93,15 +143,13 @@ class SkinGPTModel:
93
  padding=True,
94
  return_tensors="pt",
95
  ).to(self.model.device)
96
-
97
- # Create a TextIteratorStreamer for streaming output
98
  streamer = TextIteratorStreamer(
99
  self.processor.tokenizer,
100
  skip_prompt=True,
101
- skip_special_tokens=True
102
  )
103
-
104
- # Prepare the generation kwargs
105
  generation_kwargs = {
106
  **inputs,
107
  "max_new_tokens": max_new_tokens,
@@ -112,13 +160,11 @@ class SkinGPTModel:
112
  "do_sample": True,
113
  "streamer": streamer,
114
  }
115
-
116
- # Run generation in a separate thread
117
  thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
118
  thread.start()
119
-
120
- # Yield the generated text chunk by chunk
121
  for text_chunk in streamer:
122
  yield text_chunk
123
-
124
- thread.join()
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from threading import Thread
5
+ from typing import List
6
+
7
  import torch
 
8
  from qwen_vl_utils import process_vision_info
9
+ from transformers import (
10
+ AutoProcessor,
11
+ Qwen2_5_VLForConditionalGeneration,
12
+ TextIteratorStreamer,
13
+ )
14
+
15
+ DEFAULT_MODEL_PATH = "./checkpoints/full_precision"
16
+ DEFAULT_SYSTEM_PROMPT = "You are a professional AI dermatology assistant."
17
+
18
+
19
+ def resolve_model_path(model_path: str = DEFAULT_MODEL_PATH) -> str:
20
+ """Resolve a model path for both cloned-repo and local-dev layouts."""
21
+ raw_path = Path(model_path).expanduser()
22
+ repo_root = Path(__file__).resolve().parents[2]
23
+ candidates = [raw_path]
24
+
25
+ if not raw_path.is_absolute():
26
+ candidates.append(Path.cwd() / raw_path)
27
+ candidates.append(repo_root / raw_path)
28
+ if raw_path.parts and raw_path.parts[0] == repo_root.name:
29
+ candidates.append(repo_root.joinpath(*raw_path.parts[1:]))
30
+
31
+ for candidate in candidates:
32
+ if candidate.exists():
33
+ return str(candidate)
34
+ return str(raw_path)
35
+
36
+
37
+ def build_single_turn_messages(
38
+ image_path: str,
39
+ prompt: str,
40
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
41
+ ) -> List[dict]:
42
+ return [
43
+ {
44
+ "role": "user",
45
+ "content": [
46
+ {"type": "image", "image": image_path},
47
+ {"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
48
+ ],
49
+ }
50
+ ]
51
+
52
 
53
  class SkinGPTModel:
54
+ def __init__(self, model_path: str = DEFAULT_MODEL_PATH, device: str | None = None):
55
+ resolved_model_path = resolve_model_path(model_path)
56
+ self.model_path = resolved_model_path
57
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
58
+ print(f"Loading model from {resolved_model_path} on {self.device}...")
59
+
60
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
61
+ resolved_model_path,
62
  torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
63
  attn_implementation="flash_attention_2" if self.device == "cuda" else None,
64
  device_map="auto" if self.device != "mps" else None,
65
+ trust_remote_code=True,
66
  )
67
+
68
  if self.device == "mps":
69
  self.model = self.model.to(self.device)
70
 
71
  self.processor = AutoProcessor.from_pretrained(
72
+ resolved_model_path,
73
+ trust_remote_code=True,
74
+ min_pixels=256 * 28 * 28,
75
+ max_pixels=1280 * 28 * 28,
76
  )
77
  print("Model loaded successfully.")
78
 
79
+ def generate_response(
80
+ self,
81
+ messages,
82
+ max_new_tokens: int = 1024,
83
+ temperature: float = 0.7,
84
+ repetition_penalty: float = 1.2,
85
+ no_repeat_ngram_size: int = 3,
86
+ ) -> str:
87
+ text = self.processor.apply_chat_template(
88
+ messages,
89
+ tokenize=False,
90
+ add_generation_prompt=True,
91
+ )
92
  image_inputs, video_inputs = process_vision_info(messages)
93
+
94
  inputs = self.processor(
95
  text=[text],
96
  images=image_inputs,
 
107
  repetition_penalty=repetition_penalty,
108
  no_repeat_ngram_size=no_repeat_ngram_size,
109
  top_p=0.9,
110
+ do_sample=True,
111
  )
112
 
 
113
  generated_ids_trimmed = [
114
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
115
  ]
116
  output_text = self.processor.batch_decode(
117
+ generated_ids_trimmed,
118
+ skip_special_tokens=True,
119
+ clean_up_tokenization_spaces=False,
120
  )
121
+
122
  return output_text[0]
123
+
124
+ def generate_response_stream(
125
+ self,
126
+ messages,
127
+ max_new_tokens: int = 1024,
128
+ temperature: float = 0.7,
129
+ repetition_penalty: float = 1.2,
130
+ no_repeat_ngram_size: int = 3,
131
+ ):
132
+ text = self.processor.apply_chat_template(
133
+ messages,
134
+ tokenize=False,
135
+ add_generation_prompt=True,
136
+ )
137
  image_inputs, video_inputs = process_vision_info(messages)
138
+
139
  inputs = self.processor(
140
  text=[text],
141
  images=image_inputs,
 
143
  padding=True,
144
  return_tensors="pt",
145
  ).to(self.model.device)
146
+
 
147
  streamer = TextIteratorStreamer(
148
  self.processor.tokenizer,
149
  skip_prompt=True,
150
+ skip_special_tokens=True,
151
  )
152
+
 
153
  generation_kwargs = {
154
  **inputs,
155
  "max_new_tokens": max_new_tokens,
 
160
  "do_sample": True,
161
  "streamer": streamer,
162
  }
163
+
 
164
  thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
165
  thread.start()
166
+
 
167
  for text_chunk in streamer:
168
  yield text_chunk
169
+
170
+ thread.join()
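The streaming path above hands `model.generate` to a background thread and consumes a `TextIteratorStreamer` on the caller's side. The same producer/consumer shape can be sketched with only the standard library; the `fake_generate` producer below is a stand-in for the real model, not part of the repo:

```python
from queue import Queue
from threading import Thread

SENTINEL = object()

def stream_tokens(producer, *args):
    """Run `producer` in a background thread and yield its items as they
    arrive, mirroring how generate() feeds a TextIteratorStreamer."""
    q = Queue()

    def run():
        try:
            for item in producer(*args):
                q.put(item)
        finally:
            q.put(SENTINEL)  # signal end of stream

    t = Thread(target=run)
    t.start()
    while True:
        item = q.get()
        if item is SENTINEL:
            break
        yield item
    t.join()

def fake_generate(prompt):
    # Stand-in producer: emits one "token" per word.
    for word in prompt.split():
        yield word + " "

chunks = list(stream_tokens(fake_generate, "hello streaming world"))
print("".join(chunks).strip())  # hello streaming world
```

The sentinel plus `thread.join()` ensures the consumer exits cleanly even if the producer raises, which is the same reason the real `generate_response_stream` joins its thread after draining the streamer.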
inference/full_precision/run_api.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/app.py"
inference/full_precision/run_chat.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/chat.py" "$@"
inference/full_precision/run_infer.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/infer.py" "$@"
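The three wrappers share two idioms: `SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"` resolves the script's own directory so the wrapper works from any CWD, and `"${PYTHON_EXE:-python}"` uses the env var if set, falling back to `python` otherwise. A quick self-contained illustration of the fallback (the value `python3.11` is only an example):

```shell
# Demonstrates the "${VAR:-default}" fallback used by the run_*.sh wrappers.
unset PYTHON_EXE
resolved="${PYTHON_EXE:-python}"
echo "unset -> $resolved"    # unset -> python

PYTHON_EXE=python3.11
resolved="${PYTHON_EXE:-python}"
echo "set   -> $resolved"    # set   -> python3.11
```

Note that `:-` also applies the default when the variable is set but empty, which is usually what you want for an interpreter override.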
inference/inference.py DELETED
@@ -1,43 +0,0 @@
1
-
2
- import argparse
3
- from model_utils import SkinGPTModel
4
- import os
5
-
6
- def main():
7
- parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
8
- parser.add_argument("--image", type=str, required=True, help="Path to the image")
9
- parser.add_argument("--model_path", type=str, default="../checkpoint")
10
- parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
11
- args = parser.parse_args()
12
-
13
- if not os.path.exists(args.image):
14
- print(f"Error: Image not found at {args.image}")
15
- return
16
-
17
- # 1. Load the model (reusing model_utils)
18
- # so the transformers loading code is not duplicated here
19
- bot = SkinGPTModel(args.model_path)
20
-
21
- # 2. Build the single-turn message
22
- system_prompt = "You are a professional AI dermatology assistant."
23
- messages = [
24
- {
25
- "role": "user",
26
- "content": [
27
- {"type": "image", "image": args.image},
28
- {"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
29
- ]
30
- }
31
- ]
32
-
33
- # 3. Run inference
34
- print(f"\nAnalyzing {args.image}...")
35
- response = bot.generate_response(messages)
36
-
37
- print("-" * 40)
38
- print("Result:")
39
- print(response)
40
- print("-" * 40)
41
-
42
- if __name__ == "__main__":
43
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/int4_quantized/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """INT4 quantized inference package for SkinGPT-R1."""
inference/int4_quantized/__pycache__/app.cpython-311.pyc ADDED
Binary file (18 kB). View file
 
inference/int4_quantized/__pycache__/chat.cpython-311.pyc ADDED
Binary file (3.49 kB). View file
 
inference/int4_quantized/__pycache__/infer.cpython-311.pyc ADDED
Binary file (4.48 kB). View file
 
inference/int4_quantized/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (28.9 kB). View file
 
inference/{.ipynb_checkpoints/app-checkpoint.py → int4_quantized/app.py} RENAMED
@@ -1,133 +1,97 @@
1
- # app.py
2
- import uvicorn
 
 
3
  import os
4
  import shutil
 
5
  import uuid
6
- import json
7
- import re
8
- import asyncio
9
- from typing import Optional
10
- from io import BytesIO
11
  from contextlib import asynccontextmanager
12
- from PIL import Image
13
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
 
 
 
 
 
 
 
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from fastapi.responses import StreamingResponse
16
- from fastapi.concurrency import run_in_threadpool
17
- from model_utils import SkinGPTModel
18
- from deepseek_service import get_deepseek_service, DeepSeekService
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # === Configuration ===
21
- MODEL_PATH = "../checkpoint"
22
- TEMP_DIR = "./temp_uploads"
23
- os.makedirs(TEMP_DIR, exist_ok=True)
 
24
 
25
- # DeepSeek API Key
26
- DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-b221f29be052460f9e0fe12d88dd343c")
 
27
 
28
- # Global DeepSeek service instance
29
  deepseek_service: Optional[DeepSeekService] = None
30
 
31
- @asynccontextmanager
32
- async def lifespan(app: FastAPI):
33
- """Application lifespan management"""
34
- # Initialize the DeepSeek service at startup
35
- await init_deepseek()
36
- yield
37
- print("\nShutting down service...")
38
 
39
- app = FastAPI(
40
- title="SkinGPT-R1 皮肤诊断系统",
41
- description="智能皮肤诊断助手",
42
- version="1.0.0",
43
- lifespan=lifespan
44
- )
45
 
46
- # CORS configuration - allow frontend access
47
- app.add_middleware(
48
- CORSMiddleware,
49
- allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
50
- allow_credentials=True,
51
- allow_methods=["*"],
52
- allow_headers=["*"],
53
- )
54
 
55
- # Global state
56
- # chat_states: conversation history (list of messages for Qwen)
57
- # pending_images: uploaded image paths not yet sent to the LLM (state ID -> image path)
58
- chat_states = {}
59
- pending_images = {}
60
 
61
- def parse_diagnosis_result(raw_text: str) -> dict:
62
- """
63
- Parse the <think> and <answer> tags in a diagnosis result.
64
-
65
- Args:
66
- - raw_text: raw diagnosis text
67
-
68
- Returns:
69
- - dict: dictionary with thinking, answer, and raw fields
70
- """
71
- import re
72
-
73
- # Try to match complete tags first
74
- think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
75
- answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
76
-
77
- thinking = None
78
- answer = None
79
-
80
- # Handle the think tag
81
- if think_match:
82
- thinking = think_match.group(1).strip()
83
- else:
84
- # Try to match an unclosed think tag (output was truncated)
85
- unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
86
  if unclosed_think:
87
  thinking = unclosed_think.group(1).strip()
88
-
89
- # Handle the answer tag
90
- if answer_match:
91
- answer = answer_match.group(1).strip()
92
- else:
93
- # Try to match an unclosed answer tag
94
- unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
95
  if unclosed_answer:
96
  answer = unclosed_answer.group(1).strip()
97
-
98
- # If there is still no answer, clean the raw text and use it as the answer
99
  if not answer:
100
- # Remove all tags and their content
101
- cleaned = re.sub(r'<think>[\s\S]*?</think>', '', raw_text)
102
- cleaned = re.sub(r'<think>[\s\S]*', '', cleaned) # remove an unclosed think
103
- cleaned = re.sub(r'</?answer>', '', cleaned) # remove answer tags
104
- cleaned = cleaned.strip()
105
- answer = cleaned if cleaned else raw_text
106
-
107
- # Clean up any leftover tags
108
- if answer:
109
- answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
110
- if thinking:
111
- thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
112
-
113
- # Handle the "Final Answer:" format and extract what follows
114
  if answer:
115
- final_answer_match = re.search(r'Final Answer:\s*([\s\S]*)', answer, re.IGNORECASE)
 
116
  if final_answer_match:
117
  answer = final_answer_match.group(1).strip()
118
-
119
- return {
120
- "thinking": thinking if thinking else None,
121
- "answer": answer,
122
- "raw": raw_text
123
- }
124
 
125
- print("Initializing Model Service...")
126
- # Load the model globally
127
- gpt_model = SkinGPTModel(MODEL_PATH)
128
- print("Service Ready.")
 
 
 
 
 
 
129
 
130
- # Initialize the DeepSeek service (async)
131
  async def init_deepseek():
132
  global deepseek_service
133
  print("\nInitializing DeepSeek service...")
@@ -137,120 +101,115 @@ async def init_deepseek():
137
  else:
138
  print("DeepSeek service not available, will return raw results")
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  @app.post("/v1/upload/{state_id}")
141
  async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
142
- """
143
- Receive an image upload.
144
- Logic: save the image to a local temp directory and mark this state_id as having a pending image.
145
- """
146
  try:
147
- # 1. Save the image to a local temp file
148
  file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
149
  unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
150
- file_path = os.path.join(TEMP_DIR, unique_name)
151
-
152
- with open(file_path, "wb") as buffer:
153
  shutil.copyfileobj(file.file, buffer)
154
-
155
- # 2. Record the image path for use on the next predict call
156
- # For multi-image mode this could become a list; for now a single image is kept and overwritten
157
- pending_images[state_id] = file_path
158
-
159
- # 3. Initialize the chat state (if this is a new session)
160
  if state_id not in chat_states:
161
  chat_states[state_id] = []
162
-
163
- return {"message": "Image uploaded successfully", "path": file_path}
164
-
165
- except Exception as e:
166
- raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
167
 
168
  @app.post("/v1/predict/{state_id}")
169
  async def v1_predict(request: Request, state_id: str):
170
- """
171
- Receive text and run inference.
172
- Logic: check for a pending image; if present, combine it with the text into a multimodal message.
173
- """
174
  try:
175
  data = await request.json()
176
- except:
177
- raise HTTPException(status_code=400, detail="Invalid JSON")
178
-
179
  user_message = data.get("message", "")
180
  if not user_message:
181
  raise HTTPException(status_code=400, detail="Missing 'message' field")
182
 
183
- # Get or initialize the history
184
  history = chat_states.get(state_id, [])
185
-
186
- # Build the user content for the current turn
187
  current_content = []
188
-
189
- # 1. Check for a freshly uploaded image
190
  if state_id in pending_images:
191
- img_path = pending_images.pop(state_id) # take it and remove it
192
  current_content.append({"type": "image", "image": img_path})
193
-
194
- # If this is the first turn, prepend the system prompt
195
  if not history:
196
- system_prompt = "You are a professional AI dermatology assistant. "
197
- user_message = f"{system_prompt}\n\n{user_message}"
198
 
199
- # 2. Add the text
200
  current_content.append({"type": "text", "text": user_message})
201
-
202
- # 3. Update the history
203
  history.append({"role": "user", "content": current_content})
204
  chat_states[state_id] = history
205
 
206
- # 4. Run inference (in a thread pool to avoid blocking)
207
  try:
208
- response_text = await run_in_threadpool(
209
- gpt_model.generate_response,
210
- messages=history
211
- )
212
- except Exception as e:
213
- # Roll back the history (drop the user message that just failed)
214
  chat_states[state_id].pop()
215
- raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
216
 
217
- # 5. Append the reply to the history
218
  history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
219
  chat_states[state_id] = history
220
-
221
  return {"message": response_text}
222
 
 
223
  @app.post("/v1/reset/{state_id}")
224
  async def reset_chat(state_id: str):
225
- """Clear the session state"""
226
  if state_id in chat_states:
227
  del chat_states[state_id]
228
  if state_id in pending_images:
229
- # Optional: delete the temp file
230
  try:
231
- os.remove(pending_images[state_id])
232
- except:
233
  pass
234
  del pending_images[state_id]
235
  return {"message": "Chat history reset"}
236
 
 
237
  @app.get("/")
238
  async def root():
239
- """Root endpoint"""
240
  return {
241
- "name": "SkinGPT-R1 皮肤诊断系统",
242
- "version": "1.0.0",
243
  "status": "running",
244
- "description": "智能皮肤诊断助手"
245
  }
246
 
 
247
  @app.get("/health")
248
  async def health_check():
249
- """Health check"""
250
- return {
251
- "status": "healthy",
252
- "model_loaded": True
253
- }
254
 
255
  @app.post("/diagnose/stream")
256
  async def diagnose_stream(
@@ -258,126 +217,89 @@ async def diagnose_stream(
258
  text: str = Form(...),
259
  language: str = Form("zh"),
260
  ):
261
- """
262
- SSE streaming diagnosis endpoint (for the frontend).
263
- Supports image upload and text input and returns a true streaming response.
264
- Uses the DeepSeek API to refine the output format.
265
- """
266
- from queue import Queue, Empty
267
- from threading import Thread
268
-
269
  language = language if language in ("zh", "en") else "zh"
270
-
271
- # Handle the image
272
  pil_image = None
273
- temp_image_path = None
274
-
275
  if image:
276
  contents = await image.read()
277
  pil_image = Image.open(BytesIO(contents)).convert("RGB")
278
-
279
- # Create a queue for inter-thread communication
280
  result_queue = Queue()
281
- # Holds the full response and the parsed result
282
  generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
283
-
284
  def run_generation():
285
- """Run streaming generation in a background thread"""
286
  full_response = []
287
-
288
  try:
289
- # Build the messages
290
  messages = []
291
  current_content = []
292
-
293
- # Add the system prompt
294
- system_prompt = "You are a professional AI dermatology assistant." if language == "en" else "你是一个专业的AI皮肤科助手。"
295
-
296
- # If an image was provided, save it to a temp file
 
297
  if pil_image:
298
- generation_result["temp_image_path"] = os.path.join(TEMP_DIR, f"temp_{uuid.uuid4().hex}.jpg")
299
- pil_image.save(generation_result["temp_image_path"])
300
- current_content.append({"type": "image", "image": generation_result["temp_image_path"]})
301
-
302
- # Add the text
303
- prompt = f"{system_prompt}\n\n{text}"
304
- current_content.append({"type": "text", "text": prompt})
305
  messages.append({"role": "user", "content": current_content})
306
-
307
- # Streaming generation - put each chunk into the queue immediately
308
  for chunk in gpt_model.generate_response_stream(
309
  messages=messages,
310
- max_new_tokens=2048,
311
- temperature=0.7
 
312
  ):
313
  full_response.append(chunk)
314
  result_queue.put(("delta", chunk))
315
-
316
- # Parse the result
317
  response_text = "".join(full_response)
318
- parsed = parse_diagnosis_result(response_text)
319
  generation_result["full_response"] = full_response
320
- generation_result["parsed"] = parsed
321
-
322
- # Mark generation as complete
323
  result_queue.put(("generation_done", None))
324
-
325
- except Exception as e:
326
- result_queue.put(("error", str(e)))
327
-
328
  async def event_generator():
329
- """Asynchronously generate SSE events"""
330
- # Start generation in a background thread (non-blocking)
331
  gen_thread = Thread(target=run_generation)
332
  gen_thread.start()
333
-
334
  loop = asyncio.get_event_loop()
335
-
336
- # Read from the queue and send streamed content
337
  while True:
338
  try:
339
- # Non-blocking get
340
  msg_type, data = await loop.run_in_executor(
341
- None,
342
- lambda: result_queue.get(timeout=0.1)
343
  )
344
-
345
  if msg_type == "generation_done":
346
- # Streaming finished; prepare the final result
347
  break
348
- elif msg_type == "delta":
349
- yield_chunk = json.dumps({"type": "delta", "text": data}, ensure_ascii=False)
350
- yield f"data: {yield_chunk}\n\n"
351
  elif msg_type == "error":
352
  yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
353
  gen_thread.join()
354
  return
355
-
356
  except Empty:
357
- # Queue temporarily empty; keep waiting
358
  await asyncio.sleep(0.01)
359
- continue
360
-
361
  gen_thread.join()
362
-
363
- # Fetch the parsed result
364
  parsed = generation_result["parsed"]
365
  if not parsed:
366
- yield f"data: {json.dumps({'type': 'error', 'message': 'Failed to parse response'}, ensure_ascii=False)}\n\n"
367
  return
368
-
369
  raw_thinking = parsed["thinking"]
370
  raw_answer = parsed["answer"]
371
-
372
- # Refine the result with DeepSeek
373
  refined_by_deepseek = False
374
  description = None
375
  thinking = raw_thinking
376
  answer = raw_answer
377
-
378
  if deepseek_service and deepseek_service.is_loaded:
379
  try:
380
- print(f"Calling DeepSeek to refine diagnosis (language={language})...")
381
  refined = await deepseek_service.refine_diagnosis(
382
  raw_answer=raw_answer,
383
  raw_thinking=raw_thinking,
@@ -388,36 +310,35 @@ async def diagnose_stream(
388
  thinking = refined["analysis_process"]
389
  answer = refined["diagnosis_result"]
390
  refined_by_deepseek = True
391
- print(f"DeepSeek refinement completed successfully")
392
- except Exception as e:
393
- print(f"DeepSeek refinement failed, using original: {e}")
394
  else:
395
  print("DeepSeek service not available, using raw results")
396
-
397
- success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
398
-
399
- # Keep the response format consistent with the reference project
400
  final_payload = {
401
- "description": description, # image description (extracted from thinking)
402
- "thinking": thinking, # analysis process (after DeepSeek refinement)
403
- "answer": answer, # diagnosis result (after DeepSeek refinement)
404
- "raw": parsed["raw"], # raw response
405
- "refined_by_deepseek": refined_by_deepseek, # whether DeepSeek refined the output
406
  "success": True,
407
- "message": success_msg
408
  }
409
- yield_final = json.dumps({"type": "final", "result": final_payload}, ensure_ascii=False)
410
- yield f"data: {yield_final}\n\n"
411
-
412
- # Clean up the temp image
413
  temp_path = generation_result.get("temp_image_path")
414
- if temp_path and os.path.exists(temp_path):
415
  try:
416
- os.remove(temp_path)
417
- except:
418
  pass
419
-
420
  return StreamingResponse(event_generator(), media_type="text/event-stream")
421
 
422
- if __name__ == '__main__':
423
- uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
  import os
6
  import shutil
7
+ import sys
8
  import uuid
 
 
 
 
 
9
  from contextlib import asynccontextmanager
10
+ from io import BytesIO
11
+ from pathlib import Path
12
+ from queue import Empty, Queue
13
+ from threading import Thread
14
+ from typing import Optional
15
+
16
+ import uvicorn
17
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
18
+ from fastapi.concurrency import run_in_threadpool
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from fastapi.responses import StreamingResponse
21
+ from PIL import Image
22
+
23
+ try:
24
+ from .model_utils import (
25
+ DEFAULT_DO_SAMPLE,
26
+ DEFAULT_MAX_NEW_TOKENS,
27
+ DEFAULT_MODEL_PATH,
28
+ DEFAULT_REPETITION_PENALTY,
29
+ QuantizedSkinGPTModel,
30
+ )
31
+ except ImportError:
32
+ from model_utils import (
33
+ DEFAULT_DO_SAMPLE,
34
+ DEFAULT_MAX_NEW_TOKENS,
35
+ DEFAULT_MODEL_PATH,
36
+ DEFAULT_REPETITION_PENALTY,
37
+ QuantizedSkinGPTModel,
38
+ )
39
 
40
+ try:
41
+ from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
42
+ except ImportError:
43
+ sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
44
+ from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
45
 
46
+ TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
47
+ TEMP_DIR.mkdir(parents=True, exist_ok=True)
48
+ DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
49
 
 
50
  deepseek_service: Optional[DeepSeekService] = None
51
 
 
 
 
 
 
 
 
52
 
53
+ def parse_diagnosis_result(raw_text: str) -> dict:
54
+ import re
 
 
 
 
55
 
56
+ think_match = re.search(r"<think>([\s\S]*?)</think>", raw_text)
57
+ answer_match = re.search(r"<answer>([\s\S]*?)</answer>", raw_text)
 
 
 
 
 
 
58
 
59
+ thinking = think_match.group(1).strip() if think_match else None
60
+ answer = answer_match.group(1).strip() if answer_match else None
 
 
 
61
 
62
+ if not thinking:
63
+ unclosed_think = re.search(r"<think>([\s\S]*?)(?=<answer>|$)", raw_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if unclosed_think:
65
  thinking = unclosed_think.group(1).strip()
66
+
67
+ if not answer:
68
+ unclosed_answer = re.search(r"<answer>([\s\S]*?)$", raw_text)
 
 
 
 
69
  if unclosed_answer:
70
  answer = unclosed_answer.group(1).strip()
71
+
 
72
  if not answer:
73
+ cleaned = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
74
+ cleaned = re.sub(r"<think>[\s\S]*", "", cleaned)
75
+ cleaned = re.sub(r"</?answer>", "", cleaned)
76
+ answer = cleaned.strip() or raw_text
77
+
 
 
 
 
 
 
 
 
 
78
  if answer:
79
+ answer = re.sub(r"</?think>|</?answer>", "", answer).strip()
80
+ final_answer_match = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
81
  if final_answer_match:
82
  answer = final_answer_match.group(1).strip()
 
 
 
 
 
 
83
 
84
+ if thinking:
85
+ thinking = re.sub(r"</?think>|</?answer>", "", thinking).strip()
86
+
87
+ return {"thinking": thinking or None, "answer": answer, "raw": raw_text}
88
+
89
+
90
+ print("Initializing INT4 Model Service...")
91
+ gpt_model = QuantizedSkinGPTModel(DEFAULT_MODEL_PATH)
92
+ print("INT4 service ready.")
93
+
94
 
 
95
  async def init_deepseek():
96
  global deepseek_service
97
  print("\nInitializing DeepSeek service...")
 
101
  else:
102
  print("DeepSeek service not available, will return raw results")
103
 
104
+
105
+ @asynccontextmanager
106
+ async def lifespan(app: FastAPI):
107
+ await init_deepseek()
108
+ yield
109
+ print("\nShutting down INT4 service...")
110
+
111
+
112
+ app = FastAPI(
113
+ title="SkinGPT-R1 INT4 API",
114
+ description="INT4 quantized dermatology assistant backend",
115
+ version="1.1.0",
116
+ lifespan=lifespan,
117
+ )
118
+
119
+ app.add_middleware(
120
+ CORSMiddleware,
121
+ allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
122
+ allow_credentials=True,
123
+ allow_methods=["*"],
124
+ allow_headers=["*"],
125
+ )
126
+
127
+ chat_states = {}
128
+ pending_images = {}
129
+
130
+
131
  @app.post("/v1/upload/{state_id}")
132
  async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
133
+ del survey
 
 
 
134
  try:
 
135
  file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
136
  unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
137
+ file_path = TEMP_DIR / unique_name
138
+
139
+ with file_path.open("wb") as buffer:
140
  shutil.copyfileobj(file.file, buffer)
141
+
142
+ pending_images[state_id] = str(file_path)
 
 
 
 
143
  if state_id not in chat_states:
144
  chat_states[state_id] = []
145
+
146
+ return {"message": "Image uploaded successfully", "path": str(file_path)}
147
+ except Exception as exc:
148
+ raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc
149
+
150
 
151
  @app.post("/v1/predict/{state_id}")
152
  async def v1_predict(request: Request, state_id: str):
 
 
 
 
153
  try:
154
  data = await request.json()
155
+ except Exception as exc:
156
+ raise HTTPException(status_code=400, detail="Invalid JSON") from exc
157
+
158
  user_message = data.get("message", "")
159
  if not user_message:
160
  raise HTTPException(status_code=400, detail="Missing 'message' field")
161
 
 
162
  history = chat_states.get(state_id, [])
 
 
163
  current_content = []
164
+
 
165
  if state_id in pending_images:
166
+ img_path = pending_images.pop(state_id)
167
  current_content.append({"type": "image", "image": img_path})
 
 
168
  if not history:
169
+ user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"
 
170
 
 
171
  current_content.append({"type": "text", "text": user_message})
 
 
172
  history.append({"role": "user", "content": current_content})
173
  chat_states[state_id] = history
174
 
 
175
  try:
176
+ response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
177
+ except Exception as exc:
 
 
 
 
178
  chat_states[state_id].pop()
179
+ raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
180
 
 
181
  history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
182
  chat_states[state_id] = history
 
183
  return {"message": response_text}
184
 
185
+
186
  @app.post("/v1/reset/{state_id}")
187
  async def reset_chat(state_id: str):
 
188
  if state_id in chat_states:
189
  del chat_states[state_id]
190
  if state_id in pending_images:
 
191
  try:
192
+ Path(pending_images[state_id]).unlink(missing_ok=True)
193
+ except Exception:
194
  pass
195
  del pending_images[state_id]
196
  return {"message": "Chat history reset"}
197
 
198
+
199
  @app.get("/")
200
  async def root():
 
201
  return {
202
+ "name": "SkinGPT-R1 INT4 API",
203
+ "version": "1.1.0",
204
  "status": "running",
205
+ "description": "INT4 quantized dermatology assistant",
206
  }
207
 
208
+
209
  @app.get("/health")
210
  async def health_check():
211
+ return {"status": "healthy", "model_loaded": True}
212
+
 
 
 
213
 
214
  @app.post("/diagnose/stream")
215
  async def diagnose_stream(
 
217
  text: str = Form(...),
218
  language: str = Form("zh"),
219
  ):
 
 
 
 
 
 
 
 
220
  language = language if language in ("zh", "en") else "zh"
 
 
221
  pil_image = None
222
+
 
223
  if image:
224
  contents = await image.read()
225
  pil_image = Image.open(BytesIO(contents)).convert("RGB")
226
+
 
227
  result_queue = Queue()
 
228
  generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
229
+
230
  def run_generation():
 
231
  full_response = []
 
232
  try:
 
233
  messages = []
234
  current_content = []
235
+ system_prompt = (
236
+ "You are a professional AI dermatology assistant."
237
+ if language == "en"
238
+ else "你是一个专业的AI皮肤科助手。"
239
+ )
240
+
241
  if pil_image:
242
+ temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
243
+ pil_image.save(temp_image_path)
244
+ generation_result["temp_image_path"] = str(temp_image_path)
245
+ current_content.append({"type": "image", "image": str(temp_image_path)})
246
+
247
+ current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
 
248
  messages.append({"role": "user", "content": current_content})
249
+
 
250
  for chunk in gpt_model.generate_response_stream(
251
  messages=messages,
252
+ max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
253
+ do_sample=DEFAULT_DO_SAMPLE,
254
+ repetition_penalty=DEFAULT_REPETITION_PENALTY,
255
  ):
256
  full_response.append(chunk)
257
  result_queue.put(("delta", chunk))
258
+
 
259
  response_text = "".join(full_response)
 
260
  generation_result["full_response"] = full_response
261
+ generation_result["parsed"] = parse_diagnosis_result(response_text)
 
 
262
  result_queue.put(("generation_done", None))
263
+ except Exception as exc:
264
+ result_queue.put(("error", str(exc)))
265
+
 
266
  async def event_generator():
 
 
267
  gen_thread = Thread(target=run_generation)
268
  gen_thread.start()
269
+
270
  loop = asyncio.get_event_loop()
 
 
271
  while True:
272
  try:
 
273
  msg_type, data = await loop.run_in_executor(
274
+ None,
275
+ lambda: result_queue.get(timeout=0.1),
276
  )
 
277
  if msg_type == "generation_done":
 
278
  break
279
+ if msg_type == "delta":
280
+ yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
 
281
  elif msg_type == "error":
282
  yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
283
  gen_thread.join()
284
  return
 
285
  except Empty:
 
286
  await asyncio.sleep(0.01)
287
+
 
288
  gen_thread.join()
 
 
289
  parsed = generation_result["parsed"]
290
  if not parsed:
291
+ yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
292
  return
293
+
294
  raw_thinking = parsed["thinking"]
295
  raw_answer = parsed["answer"]
 
 
296
  refined_by_deepseek = False
297
  description = None
298
  thinking = raw_thinking
299
  answer = raw_answer
300
+
301
  if deepseek_service and deepseek_service.is_loaded:
302
  try:
 
303
  refined = await deepseek_service.refine_diagnosis(
304
  raw_answer=raw_answer,
305
  raw_thinking=raw_thinking,
 
310
  thinking = refined["analysis_process"]
311
  answer = refined["diagnosis_result"]
312
  refined_by_deepseek = True
313
+ except Exception as exc:
314
+ print(f"DeepSeek refinement failed, using original: {exc}")
 
315
  else:
316
  print("DeepSeek service not available, using raw results")
317
+
 
 
 
318
  final_payload = {
319
+ "description": description,
320
+ "thinking": thinking,
321
+ "answer": answer,
322
+ "raw": parsed["raw"],
323
+ "refined_by_deepseek": refined_by_deepseek,
324
  "success": True,
325
+ "message": "Diagnosis completed" if language == "en" else "诊断完成",
326
  }
327
+ yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"
328
+
 
 
329
  temp_path = generation_result.get("temp_image_path")
330
+ if temp_path:
331
  try:
332
+ Path(temp_path).unlink(missing_ok=True)
333
+ except Exception:
334
  pass
335
+
336
  return StreamingResponse(event_generator(), media_type="text/event-stream")
337
 
338
+
339
+ def main() -> None:
340
+ uvicorn.run("app:app", host="0.0.0.0", port=5901, reload=False)
341
+
342
+
343
+ if __name__ == "__main__":
344
+ main()
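The `<think>`/`<answer>` handling that both backends rely on reduces to a pair of regexes plus a fallback for unclosed tags (truncated generations). A trimmed, runnable sketch of that core logic; the sample strings are illustrative only:

```python
import re

def parse_tags(raw_text: str) -> dict:
    """Extract <think>/<answer> spans, tolerating an unclosed <answer>
    the way parse_diagnosis_result in app.py does."""
    think = re.search(r"<think>([\s\S]*?)</think>", raw_text)
    answer = re.search(r"<answer>([\s\S]*?)</answer>", raw_text)
    thinking = think.group(1).strip() if think else None
    ans = answer.group(1).strip() if answer else None
    if not ans:
        # Generation may have been cut off before </answer>.
        unclosed = re.search(r"<answer>([\s\S]*?)$", raw_text)
        if unclosed:
            ans = unclosed.group(1).strip()
    return {"thinking": thinking, "answer": ans, "raw": raw_text}

result = parse_tags("<think>Erythema noted</think><answer>Eczema")
print(result["thinking"])  # Erythema noted
print(result["answer"])    # Eczema
```

The full version in `app.py` additionally strips leftover tags and honors a trailing `Final Answer:` prefix; this sketch covers only the tag-extraction core.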
inference/{.ipynb_checkpoints/chat-checkpoint.py → int4_quantized/chat.py} RENAMED
@@ -1,48 +1,51 @@
1
- # chat.py
 
2
  import argparse
3
- import os
4
- from model_utils import SkinGPTModel
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def main():
7
- parser = argparse.ArgumentParser(description="SkinGPT-R1 Multi-turn Chat")
8
- parser.add_argument("--model_path", type=str, default="../checkpoint")
9
  parser.add_argument("--image", type=str, required=True, help="Path to initial image")
10
- args = parser.parse_args()
11
 
12
- # Initialize the model
13
- bot = SkinGPTModel(args.model_path)
14
 
15
- # Initialize the conversation history
16
- # System prompt
17
- system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."
18
-
19
- # Build the first message containing the image
20
- if not os.path.exists(args.image):
21
  print(f"Error: Image {args.image} not found.")
22
  return
23
 
24
- history = [
25
- {
26
- "role": "user",
27
- "content": [
28
- {"type": "image", "image": args.image},
29
- {"type": "text", "text": f"{system_prompt}\n\nPlease analyze this image."}
30
- ]
31
- }
32
- ]
33
 
34
- print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
35
  print(f"Image loaded: {args.image}")
36
-
37
- # Get the first-round diagnosis
38
  print("\nModel is thinking...", end="", flush=True)
39
- response = bot.generate_response(history)
40
  print(f"\rAssistant: {response}\n")
41
-
42
- # Append the assistant's reply to the history
43
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
44
 
45
- # Enter the multi-turn chat loop
46
  while True:
47
  try:
48
  user_input = input("User: ")
@@ -51,18 +54,14 @@ def main():
51
  if not user_input.strip():
52
  continue
53
 
54
- # Append the user's new question
55
  history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
56
-
57
  print("Model is thinking...", end="", flush=True)
58
- response = bot.generate_response(history)
59
  print(f"\rAssistant: {response}\n")
60
-
61
- # Append the assistant's new reply
62
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
63
-
64
  except KeyboardInterrupt:
65
  break
66
 
 
67
  if __name__ == "__main__":
68
- main()
 
1
+ from __future__ import annotations
2
+
3
  import argparse
4
+ from pathlib import Path
5
+
6
+ try:
7
+ from .model_utils import (
8
+ DEFAULT_MODEL_PATH,
9
+ QuantizedSkinGPTModel,
10
+ build_single_turn_messages,
11
+ )
12
+ except ImportError:
13
+ from model_utils import (
14
+ DEFAULT_MODEL_PATH,
15
+ QuantizedSkinGPTModel,
16
+ build_single_turn_messages,
17
+ )
18
+
19
 
20
+ def build_parser() -> argparse.ArgumentParser:
21
+ parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 multi-turn chat")
22
+ parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
23
  parser.add_argument("--image", type=str, required=True, help="Path to initial image")
24
+ return parser
25
 
 
 
26
 
27
+ def main() -> None:
28
+ args = build_parser().parse_args()
29
+
30
+ if not Path(args.image).exists():
 
 
31
  print(f"Error: Image {args.image} not found.")
32
  return
33
 
34
+ model = QuantizedSkinGPTModel(args.model_path)
35
+ history = build_single_turn_messages(
36
+ args.image,
37
+ "Please analyze this image.",
38
+ system_prompt="You are a professional AI dermatology assistant. Analyze the skin condition carefully.",
39
+ )
 
 
 
40
 
41
+ print("\n=== SkinGPT-R1 INT4 Chat (Type 'exit' to quit) ===")
42
  print(f"Image loaded: {args.image}")
43
+
 
44
  print("\nModel is thinking...", end="", flush=True)
45
+ response = model.generate_response(history)
46
  print(f"\rAssistant: {response}\n")
 
 
47
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
48
 
 
49
  while True:
50
  try:
51
  user_input = input("User: ")
 
54
  if not user_input.strip():
55
  continue
56
 
 
57
  history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
 
58
  print("Model is thinking...", end="", flush=True)
59
+ response = model.generate_response(history)
60
  print(f"\rAssistant: {response}\n")
 
 
61
  history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
 
62
  except KeyboardInterrupt:
63
  break
64
 
65
+
66
  if __name__ == "__main__":
67
+ main()
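For readers skimming the diff, the message structure this chat loop maintains can be shown standalone. A minimal sketch mirroring `build_single_turn_messages` from `model_utils.py`; the image path and reply texts below are placeholders, not real outputs:

```python
# Standalone sketch of the history the INT4 chat loop builds.
# Mirrors build_single_turn_messages; values here are placeholders.

def build_single_turn_messages(image_path, prompt, system_prompt):
    # The first user turn carries the image plus the combined prompts.
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
            ],
        }
    ]

history = build_single_turn_messages(
    "lesion.jpg",
    "Please analyze this image.",
    system_prompt="You are a professional AI dermatology assistant.",
)
# Later turns append plain text blocks, alternating roles.
history.append({"role": "assistant",
                "content": [{"type": "text", "text": "<answer>eczema</answer>"}]})
history.append({"role": "user",
                "content": [{"type": "text", "text": "What treatments are appropriate?"}]})
```

Keeping the image only in the first turn means every later call re-sends the full history, so the model re-attends to the image on each round without re-uploading it.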
inference/int4_quantized/infer.py ADDED
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import time
5
+ from pathlib import Path
6
+
7
+ try:
8
+ from .model_utils import (
9
+ DEFAULT_DO_SAMPLE,
10
+ DEFAULT_MODEL_PATH,
11
+ DEFAULT_MAX_NEW_TOKENS,
12
+ DEFAULT_PROMPT,
13
+ DEFAULT_REPETITION_PENALTY,
14
+ DEFAULT_TEMPERATURE,
15
+ DEFAULT_TOP_P,
16
+ QuantizedSkinGPTModel,
17
+ build_single_turn_messages,
18
+ )
19
+ except ImportError:
20
+ from model_utils import (
21
+ DEFAULT_DO_SAMPLE,
22
+ DEFAULT_MODEL_PATH,
23
+ DEFAULT_MAX_NEW_TOKENS,
24
+ DEFAULT_PROMPT,
25
+ DEFAULT_REPETITION_PENALTY,
26
+ DEFAULT_TEMPERATURE,
27
+ DEFAULT_TOP_P,
28
+ QuantizedSkinGPTModel,
29
+ build_single_turn_messages,
30
+ )
31
+
32
+
33
+ def build_parser() -> argparse.ArgumentParser:
34
+ parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 inference")
35
+ parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
36
+ parser.add_argument("--image_path", type=str, required=True, help="Path to the test image")
37
+ parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for diagnosis")
38
+ parser.add_argument("--max_new_tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)
39
+ parser.add_argument("--do_sample", action="store_true", default=DEFAULT_DO_SAMPLE)
40
+ parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE)
41
+ parser.add_argument("--top_p", type=float, default=DEFAULT_TOP_P)
42
+ parser.add_argument("--repetition_penalty", type=float, default=DEFAULT_REPETITION_PENALTY)
43
+ return parser
44
+
45
+
46
+ def main() -> None:
47
+ args = build_parser().parse_args()
48
+
49
+ if not Path(args.image_path).exists():
50
+ print(f"Error: Image not found at {args.image_path}")
51
+ return
52
+
53
+ print("=== [1] Initializing INT4 Quantization ===")
54
+ print("BitsAndBytesConfig will be applied during model loading.")
55
+
56
+ print("=== [2] Loading Model and Processor ===")
57
+ start_load = time.time()
58
+ model = QuantizedSkinGPTModel(args.model_path)
59
+ print(f"Model loaded in {time.time() - start_load:.2f} seconds.")
60
+
61
+ print("=== [3] Preparing Input ===")
62
+ messages = build_single_turn_messages(args.image_path, args.prompt)
63
+
64
+ print("=== [4] Generating Response ===")
65
+ start_infer = time.time()
66
+ output_text = model.generate_response(
67
+ messages,
68
+ max_new_tokens=args.max_new_tokens,
69
+ do_sample=args.do_sample,
70
+ temperature=args.temperature,
71
+ top_p=args.top_p,
72
+ repetition_penalty=args.repetition_penalty,
73
+ )
74
+
75
+ print(f"Inference completed in {time.time() - start_infer:.2f} seconds.")
76
+ print("\n================ MODEL OUTPUT ================\n")
77
+ print(output_text)
78
+ print("\n==============================================\n")
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()
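The "[1] Initializing INT4 Quantization" step resolves to the NF4 setup in `build_quantization_config` (defined in `model_utils.py` below). As a hedged config sketch, assuming `bitsandbytes` is installed and a CUDA GPU is available at load time:

```python
import torch
from transformers import BitsAndBytesConfig

# NF4 double quantization with bfloat16 compute, matching
# build_quantization_config in model_utils.py.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_quant_type="nf4",              # normal-float 4-bit buckets
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)
```

The config is passed to `from_pretrained`, so quantization happens while the checkpoint loads rather than as a separate conversion step.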
inference/int4_quantized/model_utils.py ADDED
@@ -0,0 +1,538 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from threading import Thread
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from qwen_vl_utils import process_vision_info
11
+ from transformers import (
12
+ AutoProcessor,
13
+ BitsAndBytesConfig,
14
+ StoppingCriteria,
15
+ StoppingCriteriaList,
16
+ TextIteratorStreamer,
17
+ )
18
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
19
+ Qwen2_5_VLForConditionalGeneration,
20
+ )
21
+
22
+ DEFAULT_MODEL_PATH = "./checkpoints/int4"
23
+ DEFAULT_SYSTEM_PROMPT = (
24
+ "You are a professional AI dermatology assistant. "
25
+ "Reason step by step, keep the reasoning concise, avoid repetition, "
26
+ "and always finish with <answer>...</answer>."
27
+ )
28
+ DEFAULT_MAX_NEW_TOKENS = 768
29
+ DEFAULT_CONTINUE_TOKENS = 256
30
+ DEFAULT_DO_SAMPLE = False
31
+ DEFAULT_TEMPERATURE = 0.2
32
+ DEFAULT_TOP_P = 0.9
33
+ DEFAULT_REPETITION_PENALTY = 1.15
34
+ DEFAULT_NO_REPEAT_NGRAM_SIZE = 3
35
+ DEFAULT_PROMPT = (
36
+ "Act as a dermatologist. Analyze the visual features of this skin lesion "
37
+ "step by step, and provide a final diagnosis."
38
+ )
39
+
40
+
41
+ def resolve_model_path(model_path: str = DEFAULT_MODEL_PATH) -> str:
42
+ raw_path = Path(model_path).expanduser()
43
+ repo_root = Path(__file__).resolve().parents[2]
44
+ candidates = [raw_path]
45
+
46
+ if not raw_path.is_absolute():
47
+ candidates.append(Path.cwd() / raw_path)
48
+ candidates.append(repo_root / raw_path)
49
+ if raw_path.parts and raw_path.parts[0] == repo_root.name:
50
+ candidates.append(repo_root.joinpath(*raw_path.parts[1:]))
51
+
52
+ for candidate in candidates:
53
+ if candidate.exists():
54
+ return str(candidate)
55
+ return str(raw_path)
56
+
57
+
58
+ def build_single_turn_messages(
59
+ image_path: str,
60
+ prompt: str,
61
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
62
+ ) -> list[dict]:
63
+ return [
64
+ {
65
+ "role": "user",
66
+ "content": [
67
+ {"type": "image", "image": image_path},
68
+ {"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
69
+ ],
70
+ }
71
+ ]
72
+
73
+
74
+ def build_quantization_config() -> BitsAndBytesConfig:
75
+ return BitsAndBytesConfig(
76
+ load_in_4bit=True,
77
+ bnb_4bit_quant_type="nf4",
78
+ bnb_4bit_compute_dtype=torch.bfloat16,
79
+ bnb_4bit_use_double_quant=True,
80
+ )
81
+
82
+
83
+ def resolve_quantized_device_map():
84
+ if not torch.cuda.is_available():
85
+ raise RuntimeError("INT4 quantized inference requires a CUDA GPU.")
86
+ return {"": f"cuda:{torch.cuda.current_device()}"}
87
+
88
+
89
+ class StopOnTokenSequence(StoppingCriteria):
90
+ def __init__(self, stop_ids: list[int]):
91
+ super().__init__()
92
+ self.stop_ids = stop_ids
93
+ self.stop_length = len(stop_ids)
94
+
95
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
96
+ if self.stop_length == 0 or input_ids.shape[1] < self.stop_length:
97
+ return False
98
+ return input_ids[0, -self.stop_length :].tolist() == self.stop_ids
99
+
100
+
101
+ class ExpertBlock(nn.Module):
102
+ def __init__(self, hidden_dim, bottleneck_dim=64):
103
+ super().__init__()
104
+ self.net = nn.Sequential(
105
+ nn.Linear(hidden_dim, bottleneck_dim),
106
+ nn.ReLU(),
107
+ nn.Linear(bottleneck_dim, hidden_dim),
108
+ )
109
+
110
+ def forward(self, x):
111
+ return self.net(x)
112
+
113
+
114
+ class SkinAwareMoEAdapter(nn.Module):
115
+ def __init__(self, hidden_dim, num_experts=8, top_k=2, bottleneck_dim=64):
116
+ super().__init__()
117
+ self.num_experts = num_experts
118
+ self.top_k = top_k
119
+ self.router_img = nn.Linear(hidden_dim, num_experts, bias=False)
120
+ self.router_skin = nn.Linear(3, num_experts, bias=False)
121
+ self.experts = nn.ModuleList(
122
+ [ExpertBlock(hidden_dim, bottleneck_dim) for _ in range(num_experts)]
123
+ )
124
+
125
+ def forward(self, x: torch.Tensor, skin_probs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
126
+ img_logits = self.router_img(x)
127
+ skin_bias = self.router_skin(skin_probs)
128
+ router_logits = img_logits + skin_bias
129
+ router_probs = F.softmax(router_logits, dim=-1)
130
+
131
+ top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
132
+ top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-6)
133
+
134
+ final_output = torch.zeros_like(x)
135
+ for expert_idx, expert in enumerate(self.experts):
136
+ expert_mask = top_k_indices == expert_idx
137
+ if expert_mask.any():
138
+ rows, k_indices = torch.where(expert_mask)
139
+ inp = x[rows]
140
+ out = expert(inp)
141
+ weights = top_k_probs[rows, k_indices].unsqueeze(-1)
142
+ final_output.index_add_(0, rows, (out * weights).to(final_output.dtype))
143
+
144
+ mean_prob = router_probs.mean(0)
145
+ mask_all = torch.zeros_like(router_probs)
146
+ mask_all.scatter_(1, top_k_indices, 1.0)
147
+ mean_freq = mask_all.mean(0)
148
+ aux_loss = (mean_prob * mean_freq).sum() * self.num_experts
149
+
150
+ return x + final_output, aux_loss
151
+
152
+
153
+ class PatchDistillHead(nn.Module):
154
+ def __init__(
155
+ self,
156
+ embed_dim: int = 1024,
157
+ adapter_layers: int = 4,
158
+ in_dim: Optional[int] = None,
159
+ out_dim: Optional[int] = None,
160
+ num_experts: int = 8,
161
+ top_k: int = 2,
162
+ ):
163
+ super().__init__()
164
+ self.embed_dim = embed_dim
165
+ self.in_proj = None if in_dim is None else nn.Linear(in_dim, embed_dim, bias=False)
166
+ self.skin_classifier = nn.Sequential(
167
+ nn.Linear(embed_dim, 64),
168
+ nn.ReLU(),
169
+ nn.Linear(64, 3),
170
+ )
171
+ self.adapters = nn.ModuleList(
172
+ [
173
+ SkinAwareMoEAdapter(embed_dim, num_experts=num_experts, top_k=top_k)
174
+ for _ in range(adapter_layers)
175
+ ]
176
+ )
177
+ self.out_proj: nn.Module = (
178
+ nn.Identity() if out_dim is None else nn.Linear(embed_dim, out_dim)
179
+ )
180
+
181
+ def _ensure_in_proj(self, din: int, device, dtype):
182
+ if self.in_proj is None:
183
+ self.in_proj = nn.Linear(din, self.embed_dim, bias=False).to(device=device, dtype=dtype)
184
+
185
+ def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> dict:
186
+ _, din = pixel_values.shape
187
+ counts = (image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]).tolist()
188
+ device, dtype = pixel_values.device, pixel_values.dtype
189
+ self._ensure_in_proj(din, device, dtype)
190
+ chunks = torch.split(pixel_values, counts, dim=0)
191
+
192
+ pooled, all_skin_logits = [], []
193
+ total_aux_loss = torch.tensor(0.0, device=device, dtype=dtype)
194
+
195
+ for x in chunks:
196
+ h = self.in_proj(x)
197
+ global_feat = h.mean(dim=0, keepdim=True)
198
+ skin_logits = self.skin_classifier(global_feat)
199
+ skin_probs = F.softmax(skin_logits, dim=-1)
200
+ all_skin_logits.append(skin_logits)
201
+ skin_probs_expanded = skin_probs.expand(h.size(0), -1)
202
+
203
+ for adapter in self.adapters:
204
+ h, layer_loss = adapter(h, skin_probs_expanded)
205
+ total_aux_loss += layer_loss
206
+ pooled.append(h.mean(dim=0))
207
+
208
+ vision_embed = torch.stack(pooled, dim=0)
209
+ vision_proj = self.out_proj(vision_embed)
210
+ return {
211
+ "vision_embed": vision_embed,
212
+ "vision_proj": vision_proj,
213
+ "aux_loss": total_aux_loss,
214
+ "skin_logits": torch.cat(all_skin_logits, dim=0),
215
+ }
216
+
217
+ def configure_out_dim(self, out_dim: int):
218
+ if isinstance(self.out_proj, nn.Linear) and self.out_proj.out_features == out_dim:
219
+ return
220
+ self.out_proj = (
221
+ nn.Linear(self.embed_dim, out_dim, bias=False)
222
+ if out_dim != self.embed_dim
223
+ else nn.Identity()
224
+ )
225
+ try:
226
+ params = next(self.parameters())
227
+ self.out_proj.to(device=params.device, dtype=params.dtype)
228
+ except StopIteration:
229
+ pass
230
+
231
+
232
+ class SkinVLModelWithAdapter(Qwen2_5_VLForConditionalGeneration):
233
+ def __init__(self, config):
234
+ super().__init__(config)
235
+ self.distill_head = PatchDistillHead(
236
+ embed_dim=1024,
237
+ adapter_layers=4,
238
+ num_experts=8,
239
+ top_k=2,
240
+ in_dim=1176,
241
+ )
242
+ bottleneck = 64
243
+ self.text_bias = nn.Sequential(
244
+ nn.Linear(1024, bottleneck, bias=False),
245
+ nn.Tanh(),
246
+ nn.Linear(bottleneck, config.hidden_size, bias=False),
247
+ )
248
+ self.logit_bias_scale = nn.Parameter(torch.tensor(2.5, dtype=torch.bfloat16))
249
+
250
+ def forward(self, *args, **kwargs):
251
+ skin_vocab_mask = kwargs.pop("skin_vocab_mask", None)
252
+ skin_labels = kwargs.get("skin_labels", None)
253
+ pixel_values = kwargs.get("pixel_values", None)
254
+ image_grid_thw = kwargs.get("image_grid_thw", None)
255
+
256
+ if isinstance(pixel_values, list):
257
+ try:
258
+ pixel_values = torch.stack(pixel_values)
259
+ kwargs["pixel_values"] = pixel_values
260
+ except Exception:
261
+ pass
262
+
263
+ outputs = super().forward(*args, **kwargs)
264
+
265
+ vision_embed = None
266
+ loss_skin = torch.tensor(0.0, device=outputs.logits.device)
267
+ aux_loss = torch.tensor(0.0, device=outputs.logits.device)
+ # Default so the unconditional side[...] reads below are safe when no image is passed.
+ side = {"vision_proj": None, "skin_logits": None}
268
+
269
+ if pixel_values is not None and image_grid_thw is not None:
270
+ if not isinstance(pixel_values, torch.Tensor):
271
+ if isinstance(pixel_values, list):
272
+ pixel_values = torch.stack(pixel_values)
273
+ else:
274
+ pixel_values = torch.tensor(pixel_values)
275
+
276
+ image_grid_thw = image_grid_thw.to(pixel_values.device)
277
+ side = self.distill_head(pixel_values=pixel_values, image_grid_thw=image_grid_thw)
278
+ vision_embed = side["vision_embed"]
279
+ aux_loss = side["aux_loss"]
280
+
281
+ if skin_labels is not None:
282
+ skin_labels = skin_labels.to(side["skin_logits"].device)
283
+ loss_skin = nn.CrossEntropyLoss()(side["skin_logits"], skin_labels)
284
+
285
+ setattr(outputs, "vision_embed", vision_embed)
286
+ setattr(outputs, "vision_proj", side["vision_proj"])
287
+ setattr(outputs, "loss_skin", loss_skin)
288
+ setattr(outputs, "aux_loss", aux_loss)
289
+ setattr(outputs, "skin_logits", side["skin_logits"])
290
+
291
+ pack_vision_proj = (
292
+ side["vision_proj"]
293
+ if side["vision_proj"] is not None
294
+ else torch.tensor(0.0, device=aux_loss.device)
295
+ )
296
+ pack_skin_logits = (
297
+ side["skin_logits"]
298
+ if side["skin_logits"] is not None
299
+ else torch.tensor(0.0, device=aux_loss.device)
300
+ )
301
+ outputs.attentions = (pack_vision_proj, aux_loss, pack_skin_logits)
302
+
303
+ self.latest_side_output = {
304
+ "vision_proj": side["vision_proj"],
305
+ "aux_loss": aux_loss,
306
+ "skin_logits": side["skin_logits"],
307
+ }
308
+
309
+ if hasattr(outputs, "logits") and vision_embed is not None and skin_vocab_mask is not None:
310
+ bias_features = self.text_bias(vision_embed.to(self.logit_bias_scale.dtype))
311
+ lm_weight = self.lm_head.weight.to(bias_features.dtype)
312
+ vocab_bias = F.linear(bias_features, lm_weight)
313
+ scale = self.logit_bias_scale.to(outputs.logits.dtype)
314
+ outputs.logits = outputs.logits + (scale * vocab_bias[:, None, :] * skin_vocab_mask)
315
+
316
+ if outputs.loss is not None:
317
+ outputs.loss = outputs.loss + loss_skin + (0.01 * aux_loss)
318
+
319
+ return outputs
320
+
321
+ def freeze_all_but_distill(self):
322
+ self.requires_grad_(False)
323
+ for params in self.distill_head.parameters():
324
+ params.requires_grad_(True)
325
+ for params in self.text_bias.parameters():
326
+ params.requires_grad_(True)
327
+ self.logit_bias_scale.requires_grad_(True)
328
+
329
+ def configure_out_dim(self, out_dim: int):
330
+ self.distill_head.configure_out_dim(out_dim)
331
+
332
+ def project_only(self, vision_embed: torch.Tensor) -> torch.Tensor:
333
+ return self.distill_head.out_proj(vision_embed)
334
+
335
+
336
+ def load_quantized_model_and_processor(model_path: str = DEFAULT_MODEL_PATH):
337
+ resolved_model_path = resolve_model_path(model_path)
338
+ quantization_config = build_quantization_config()
339
+ model = SkinVLModelWithAdapter.from_pretrained(
340
+ resolved_model_path,
341
+ device_map=resolve_quantized_device_map(),
342
+ quantization_config=quantization_config,
343
+ attn_implementation="sdpa",
344
+ )
345
+ model.eval()
346
+ processor = AutoProcessor.from_pretrained(
347
+ resolved_model_path,
348
+ min_pixels=256 * 28 * 28,
349
+ max_pixels=1280 * 28 * 28,
350
+ )
351
+ return model, processor
352
+
353
+
354
+ def get_model_device(model) -> torch.device:
355
+ try:
356
+ return model.device
357
+ except AttributeError:
358
+ return next(model.parameters()).device
359
+
360
+
361
+ def prepare_inputs(processor, model, messages: list[dict]):
362
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
363
+ image_inputs, video_inputs = process_vision_info(messages)
364
+ inputs = processor(
365
+ text=[text],
366
+ images=image_inputs,
367
+ videos=video_inputs,
368
+ padding=True,
369
+ return_tensors="pt",
370
+ ).to(get_model_device(model))
371
+ inputs.pop("mm_token_type_ids", None)
372
+ return inputs
373
+
374
+
375
+ class QuantizedSkinGPTModel:
376
+ def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
377
+ resolved_model_path = resolve_model_path(model_path)
378
+ print(f"Loading INT4 model from {resolved_model_path}...")
379
+ self.model, self.processor = load_quantized_model_and_processor(resolved_model_path)
380
+ self.model_path = resolved_model_path
381
+ self.device = get_model_device(self.model)
382
+ self.stop_ids = self.processor.tokenizer.encode("</answer>", add_special_tokens=False)
383
+ print(f"Model loaded successfully on {self.device}.")
384
+
385
+ @staticmethod
386
+ def has_complete_answer(text: str) -> bool:
387
+ return "<answer>" in text and "</answer>" in text
388
+
389
+ def _build_generation_kwargs(
390
+ self,
391
+ inputs,
392
+ max_new_tokens: int,
393
+ do_sample: bool,
394
+ temperature: float,
395
+ repetition_penalty: float,
396
+ top_p: float,
397
+ no_repeat_ngram_size: int,
398
+ streamer=None,
399
+ ) -> dict:
400
+ generation_kwargs = {
401
+ **inputs,
402
+ "max_new_tokens": max_new_tokens,
403
+ "do_sample": do_sample,
404
+ "repetition_penalty": repetition_penalty,
405
+ "no_repeat_ngram_size": no_repeat_ngram_size,
406
+ "use_cache": True,
407
+ "stopping_criteria": StoppingCriteriaList([StopOnTokenSequence(self.stop_ids)]),
408
+ }
409
+ if streamer is not None:
410
+ generation_kwargs["streamer"] = streamer
411
+ if do_sample:
412
+ generation_kwargs["temperature"] = temperature
413
+ generation_kwargs["top_p"] = top_p
414
+ return generation_kwargs
415
+
416
+ def _generate_text(
417
+ self,
418
+ messages,
419
+ max_new_tokens: int,
420
+ do_sample: bool,
421
+ temperature: float,
422
+ repetition_penalty: float,
423
+ top_p: float,
424
+ no_repeat_ngram_size: int,
425
+ ) -> str:
426
+ inputs = prepare_inputs(self.processor, self.model, messages)
427
+ generation_kwargs = self._build_generation_kwargs(
428
+ inputs=inputs,
429
+ max_new_tokens=max_new_tokens,
430
+ do_sample=do_sample,
431
+ temperature=temperature,
432
+ repetition_penalty=repetition_penalty,
433
+ top_p=top_p,
434
+ no_repeat_ngram_size=no_repeat_ngram_size,
435
+ )
436
+
437
+ with torch.inference_mode():
438
+ generated_ids = self.model.generate(**generation_kwargs)
439
+
440
+ generated_ids_trimmed = [
441
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
442
+ ]
443
+ output_text = self.processor.batch_decode(
444
+ generated_ids_trimmed,
445
+ skip_special_tokens=True,
446
+ clean_up_tokenization_spaces=False,
447
+ )
448
+ return output_text[0]
449
+
450
+ def generate_response(
451
+ self,
452
+ messages,
453
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
454
+ continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
455
+ do_sample: bool = DEFAULT_DO_SAMPLE,
456
+ temperature: float = DEFAULT_TEMPERATURE,
457
+ repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
458
+ top_p: float = DEFAULT_TOP_P,
459
+ no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
460
+ ) -> str:
461
+ output_text = self._generate_text(
462
+ messages=messages,
463
+ max_new_tokens=max_new_tokens,
464
+ do_sample=do_sample,
465
+ temperature=temperature,
466
+ repetition_penalty=repetition_penalty,
467
+ top_p=top_p,
468
+ no_repeat_ngram_size=no_repeat_ngram_size,
469
+ )
470
+ if not self.has_complete_answer(output_text) and continue_tokens > 0:
471
+ output_text = self._generate_text(
472
+ messages=messages,
473
+ max_new_tokens=max_new_tokens + continue_tokens,
474
+ do_sample=do_sample,
475
+ temperature=temperature,
476
+ repetition_penalty=repetition_penalty,
477
+ top_p=top_p,
478
+ no_repeat_ngram_size=no_repeat_ngram_size,
479
+ )
480
+ return output_text
481
+
482
+ def generate_response_stream(
483
+ self,
484
+ messages,
485
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
486
+ continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
487
+ do_sample: bool = DEFAULT_DO_SAMPLE,
488
+ temperature: float = DEFAULT_TEMPERATURE,
489
+ repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
490
+ top_p: float = DEFAULT_TOP_P,
491
+ no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
492
+ ):
493
+ inputs = prepare_inputs(self.processor, self.model, messages)
494
+ streamer = TextIteratorStreamer(
495
+ self.processor.tokenizer,
496
+ skip_prompt=True,
497
+ skip_special_tokens=True,
498
+ )
499
+ generation_kwargs = self._build_generation_kwargs(
500
+ inputs=inputs,
501
+ max_new_tokens=max_new_tokens,
502
+ do_sample=do_sample,
503
+ temperature=temperature,
504
+ repetition_penalty=repetition_penalty,
505
+ top_p=top_p,
506
+ no_repeat_ngram_size=no_repeat_ngram_size,
507
+ streamer=streamer,
508
+ )
509
+
510
+ def _generate():
511
+ with torch.inference_mode():
512
+ self.model.generate(**generation_kwargs)
513
+
514
+ thread = Thread(target=_generate)
515
+ thread.start()
516
+
517
+ partial_chunks = []
518
+ for text_chunk in streamer:
519
+ partial_chunks.append(text_chunk)
520
+ yield text_chunk
521
+
522
+ thread.join()
523
+
524
+ partial_text = "".join(partial_chunks)
525
+ if not self.has_complete_answer(partial_text) and continue_tokens > 0:
526
+ completed_text = self._generate_text(
527
+ messages=messages,
528
+ max_new_tokens=max_new_tokens + continue_tokens,
529
+ do_sample=do_sample,
530
+ temperature=temperature,
531
+ repetition_penalty=repetition_penalty,
532
+ top_p=top_p,
533
+ no_repeat_ngram_size=no_repeat_ngram_size,
534
+ )
535
+ if completed_text.startswith(partial_text):
536
+ tail_text = completed_text[len(partial_text) :]
537
+ if tail_text:
538
+ yield tail_text
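The fallback in `generate_response` (and at the tail of `generate_response_stream`) reduces to a simple policy: if the first pass never closes the `<answer>` tag, rerun once with a larger token budget. A minimal standalone sketch, with a stand-in generator in place of the real model:

```python
def has_complete_answer(text: str) -> bool:
    # Same check as QuantizedSkinGPTModel.has_complete_answer.
    return "<answer>" in text and "</answer>" in text

def generate_with_retry(generate_fn, max_new_tokens=768, continue_tokens=256):
    # First pass with the default budget.
    text = generate_fn(max_new_tokens)
    if not has_complete_answer(text) and continue_tokens > 0:
        # One retry with an extended budget; generation is restarted
        # from the prompt, not resumed from the truncated output.
        text = generate_fn(max_new_tokens + continue_tokens)
    return text

# Stand-in for model inference: only closes the tag given a larger budget.
def fake_generate(budget: int) -> str:
    return "<answer>melanoma</answer>" if budget > 768 else "<answer>melan"

result = generate_with_retry(fake_generate)  # retries once with budget 1024
```

Note the cost trade-off: a truncated first pass is discarded and the whole generation is paid for again, which is why `continue_tokens` is kept modest.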
inference/int4_quantized/run_api.sh ADDED
@@ -0,0 +1,6 @@
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/app.py"
inference/int4_quantized/run_chat.sh ADDED
@@ -0,0 +1,6 @@
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/chat.py" "$@"
inference/int4_quantized/run_infer.sh ADDED
@@ -0,0 +1,6 @@
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ PYTHON_EXE="${PYTHON_EXE:-python}"
5
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
6
+ "${PYTHON_EXE}" "${SCRIPT_DIR}/infer.py" "$@"
inference/int4_quantized/test_single.sh ADDED
@@ -0,0 +1,6 @@
 
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
5
+
6
+ "${SCRIPT_DIR}/run_infer.sh" "$@"
inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg DELETED
Binary file (47.9 kB)
 
inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg DELETED
Binary file (80.7 kB)
 
inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg DELETED
Binary file (47.9 kB)
 
inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg DELETED
Binary file (80.7 kB)
 
requirements.txt CHANGED
@@ -14,6 +14,10 @@ fastapi>=0.100.0
14
  uvicorn>=0.20.0
15
  python-multipart>=0.0.6
16
  openai>=1.0.0 # For DeepSeek API (OpenAI-compatible)
 
 
 
 
17
 
18
  # Install latest transformers from source (Required for Qwen2.5-VL/Vision-R1)
19
  git+https://github.com/huggingface/transformers.git
@@ -23,4 +27,4 @@ git+https://github.com/huggingface/transformers.git
23
 
24
  # For potential future demo usage
25
  gradio==5.4.0
26
- gradio_client==1.4.2
 
14
  uvicorn>=0.20.0
15
  python-multipart>=0.0.6
16
  openai>=1.0.0 # For DeepSeek API (OpenAI-compatible)
17
+ bitsandbytes>=0.43.0 # Required for INT4 quantized inference
18
+ # Attention notes:
19
+ # - SDPA is built into PyTorch 2.x
20
+ # - flash-attn is optional and only useful on GPU architectures it officially supports
21
 
22
  # Install latest transformers from source (Required for Qwen2.5-VL/Vision-R1)
23
  git+https://github.com/huggingface/transformers.git
 
27
 
28
  # For potential future demo usage
29
  gradio==5.4.0
30
+ gradio_client==1.4.2