Yuhao commited on
Commit ·
52a881a
1
Parent(s): f7f33b5
Restructure inference and add INT4 serving
Browse files- LICENSE +9 -0
- README.md +109 -56
- inference/.ipynb_checkpoints/deepseek_service-checkpoint.py +0 -384
- inference/.ipynb_checkpoints/demo-checkpoint.py +0 -76
- inference/.ipynb_checkpoints/inference-checkpoint.py +0 -43
- inference/.ipynb_checkpoints/model_utils-checkpoint.py +0 -120
- inference/README.md +11 -0
- inference/__init__.py +1 -0
- inference/__pycache__/app.cpython-311.pyc +0 -0
- inference/__pycache__/deepseek_service.cpython-311.pyc +0 -0
- inference/__pycache__/model_utils.cpython-311.pyc +0 -0
- inference/demo.py +0 -79
- inference/full_precision/__init__.py +1 -0
- inference/full_precision/__pycache__/app.cpython-311.pyc +0 -0
- inference/full_precision/__pycache__/chat.cpython-311.pyc +0 -0
- inference/full_precision/__pycache__/deepseek_service.cpython-311.pyc +0 -0
- inference/full_precision/__pycache__/demo.cpython-311.pyc +0 -0
- inference/full_precision/__pycache__/infer.cpython-311.pyc +0 -0
- inference/full_precision/__pycache__/model_utils.cpython-311.pyc +0 -0
- inference/{app.py → full_precision/app.py} +162 -256
- inference/{chat.py → full_precision/chat.py} +38 -35
- inference/{deepseek_service.py → full_precision/deepseek_service.py} +86 -199
- inference/full_precision/demo.py +41 -0
- inference/full_precision/infer.py +54 -0
- inference/{model_utils.py → full_precision/model_utils.py} +103 -57
- inference/full_precision/run_api.sh +6 -0
- inference/full_precision/run_chat.sh +6 -0
- inference/full_precision/run_infer.sh +6 -0
- inference/inference.py +0 -43
- inference/int4_quantized/__init__.py +1 -0
- inference/int4_quantized/__pycache__/app.cpython-311.pyc +0 -0
- inference/int4_quantized/__pycache__/chat.cpython-311.pyc +0 -0
- inference/int4_quantized/__pycache__/infer.cpython-311.pyc +0 -0
- inference/int4_quantized/__pycache__/model_utils.cpython-311.pyc +0 -0
- inference/{.ipynb_checkpoints/app-checkpoint.py → int4_quantized/app.py} +181 -260
- inference/{.ipynb_checkpoints/chat-checkpoint.py → int4_quantized/chat.py} +37 -38
- inference/int4_quantized/infer.py +82 -0
- inference/int4_quantized/model_utils.py +538 -0
- inference/int4_quantized/run_api.sh +6 -0
- inference/int4_quantized/run_chat.sh +6 -0
- inference/int4_quantized/run_infer.sh +6 -0
- inference/int4_quantized/test_single.sh +6 -0
- inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg +0 -0
- inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg +0 -0
- inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg +0 -0
- inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg +0 -0
- requirements.txt +5 -1
LICENSE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SkinGPT-R1 is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license.
|
| 2 |
+
|
| 3 |
+
License summary:
|
| 4 |
+
- Attribution required
|
| 5 |
+
- Non-commercial use only
|
| 6 |
+
- Share adaptations under the same license
|
| 7 |
+
|
| 8 |
+
Full license text:
|
| 9 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
README.md
CHANGED
|
@@ -10,94 +10,147 @@ tags:
|
|
| 10 |
|
| 11 |
# SkinGPT-R1
|
| 12 |
|
| 13 |
-
**
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
##
|
| 20 |
|
| 21 |
-
|
| 22 |
|
| 23 |
-
##
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
```bash
|
| 28 |
conda create -n skingpt-r1 python=3.10 -y
|
| 29 |
conda activate skingpt-r1
|
|
|
|
| 30 |
```
|
| 31 |
|
| 32 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
```bash
|
| 35 |
-
|
| 36 |
```
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
```bash
|
| 41 |
-
|
| 42 |
```
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
```bash
|
| 53 |
-
|
| 54 |
```
|
| 55 |
|
| 56 |
-
|
| 57 |
|
| 58 |
-
To have a multi-turn conversation (e.g., asking follow-up questions about the diagnosis) in your terminal:
|
| 59 |
```bash
|
| 60 |
-
|
| 61 |
```
|
| 62 |
-
### FastAPI Backend Deployment
|
| 63 |
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
```bash
|
| 69 |
-
|
| 70 |
```
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
#
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
)
|
| 95 |
-
print("AI:", response.json()["message"])
|
| 96 |
-
|
| 97 |
-
# 3. Ask Follow-up
|
| 98 |
-
response = requests.post(
|
| 99 |
-
f"{API_URL}/v1/predict/{STATE_ID}",
|
| 100 |
-
json={"message": "What treatment do you recommend?"}
|
| 101 |
-
)
|
| 102 |
-
print("AI:", response.json()["message"])
|
| 103 |
-
```
|
|
|
|
| 10 |
|
| 11 |
# SkinGPT-R1
|
| 12 |
|
| 13 |
+
**Update:** We will soon release the **SkinGPT-R1-7B** weights.
|
| 14 |
|
| 15 |
+
SkinGPT-R1 is a dermatological reasoning vision language model for research and education.
|
| 16 |
|
| 17 |
+
From **The Chinese University of Hong Kong, Shenzhen (CUHKSZ)**.
|
| 18 |
|
| 19 |
+
## Disclaimer
|
| 20 |
|
| 21 |
+
This project is for **research and educational use only**. It is **not** a substitute for professional medical advice, diagnosis, or treatment.
|
| 22 |
|
| 23 |
+
## License
|
| 24 |
|
| 25 |
+
This repository is released under **CC BY-NC-SA 4.0**.
|
| 26 |
+
See [LICENSE](/Users/smac/Documents/SkinGPT-R1/LICENSE) for details.
|
| 27 |
+
|
| 28 |
+
## Structure
|
| 29 |
+
|
| 30 |
+
```text
|
| 31 |
+
SkinGPT-R1/
|
| 32 |
+
├── checkpoints/
|
| 33 |
+
├── inference/
|
| 34 |
+
│ ├── full_precision/
|
| 35 |
+
│ └── int4_quantized/
|
| 36 |
+
├── requirements.txt
|
| 37 |
+
└── README.md
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Checkpoint paths:
|
| 41 |
+
|
| 42 |
+
- Full precision: `./checkpoints/full_precision`
|
| 43 |
+
- INT4 quantized: `./checkpoints/int4`
|
| 44 |
+
|
| 45 |
+
## Install
|
| 46 |
|
| 47 |
```bash
|
| 48 |
conda create -n skingpt-r1 python=3.10 -y
|
| 49 |
conda activate skingpt-r1
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
```
|
| 52 |
|
| 53 |
+
## Attention Backend Notes
|
| 54 |
+
|
| 55 |
+
This repo uses two attention acceleration paths:
|
| 56 |
+
|
| 57 |
+
- `flash_attention_2`: external package, optional
|
| 58 |
+
- `sdpa`: PyTorch native scaled dot product attention
|
| 59 |
+
|
| 60 |
+
Recommended choice for this repo:
|
| 61 |
+
|
| 62 |
+
- RTX 50 series: use `sdpa`
|
| 63 |
+
- A100 / RTX 3090 / RTX 4090 / H100 and other GPUs explicitly listed by the FlashAttention project: you can try `flash_attention_2`
|
| 64 |
+
|
| 65 |
+
Practical notes:
|
| 66 |
+
|
| 67 |
+
- The current repo pins `torch==2.4.0`, and SDPA is already built into PyTorch in this version.
|
| 68 |
+
- FlashAttention's official README currently lists Ampere, Ada, and Hopper support for FlashAttention-2. It does not list RTX 50 / Blackwell consumer GPUs in that section, so this repo defaults to `sdpa` for that path.
|
| 69 |
+
- PyTorch 2.5 added a newer cuDNN SDPA backend for H100-class or newer GPUs, but this repo is pinned to PyTorch 2.4, so you should not assume those 2.5-specific gains here.
|
| 70 |
+
|
| 71 |
+
If you are on an RTX 5090 and `flash-attn` is unavailable or unstable in your environment, use the INT4 path in this repo, which is already configured with `attn_implementation="sdpa"`.
|
| 72 |
+
|
| 73 |
+
## Usage
|
| 74 |
+
|
| 75 |
+
### Full Precision
|
| 76 |
+
|
| 77 |
+
Single image:
|
| 78 |
|
| 79 |
```bash
|
| 80 |
+
bash inference/full_precision/run_infer.sh --image ./test_images/lesion.jpg
|
| 81 |
```
|
| 82 |
|
| 83 |
+
Multi-turn chat:
|
| 84 |
|
| 85 |
```bash
|
| 86 |
+
bash inference/full_precision/run_chat.sh --image ./test_images/lesion.jpg
|
| 87 |
```
|
| 88 |
|
| 89 |
+
API service:
|
| 90 |
|
| 91 |
+
```bash
|
| 92 |
+
bash inference/full_precision/run_api.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
Default API port: `5900`
|
| 96 |
|
| 97 |
+
### INT4 Quantized
|
| 98 |
|
| 99 |
+
Single image:
|
| 100 |
|
| 101 |
```bash
|
| 102 |
+
bash inference/int4_quantized/run_infer.sh --image_path ./test_images/lesion.jpg
|
| 103 |
```
|
| 104 |
|
| 105 |
+
Multi-turn chat:
|
| 106 |
|
|
|
|
| 107 |
```bash
|
| 108 |
+
bash inference/int4_quantized/run_chat.sh --image ./test_images/lesion.jpg
|
| 109 |
```
|
|
|
|
| 110 |
|
| 111 |
+
API service:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
bash inference/int4_quantized/run_api.sh
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Default API port: `5901`
|
| 118 |
+
|
| 119 |
+
The INT4 path uses:
|
| 120 |
+
|
| 121 |
+
- `bitsandbytes` 4-bit quantization
|
| 122 |
+
- `attn_implementation="sdpa"`
|
| 123 |
+
- the adapter-aware quantized model implementation in `inference/int4_quantized/`
|
| 124 |
|
| 125 |
+
## GPU Selection
|
| 126 |
+
|
| 127 |
+
You do not need to add `CUDA_VISIBLE_DEVICES=0` if the machine has only one visible GPU or if you are fine with the default CUDA device.
|
| 128 |
+
|
| 129 |
+
Use it only when you want to pin the process to a specific GPU, for example on a multi-GPU server:
|
| 130 |
|
| 131 |
```bash
|
| 132 |
+
CUDA_VISIBLE_DEVICES=0 bash inference/int4_quantized/run_infer.sh --image_path ./test_images/lesion.jpg
|
| 133 |
```
|
| 134 |
+
|
| 135 |
+
The same pattern also works for:
|
| 136 |
+
|
| 137 |
+
- `inference/full_precision/run_infer.sh`
|
| 138 |
+
- `inference/full_precision/run_chat.sh`
|
| 139 |
+
- `inference/full_precision/run_api.sh`
|
| 140 |
+
- `inference/int4_quantized/run_chat.sh`
|
| 141 |
+
- `inference/int4_quantized/run_api.sh`
|
| 142 |
+
|
| 143 |
+
## API Endpoints
|
| 144 |
+
|
| 145 |
+
Both API services expose the same endpoints:
|
| 146 |
+
|
| 147 |
+
- `POST /v1/upload/{state_id}`
|
| 148 |
+
- `POST /v1/predict/{state_id}`
|
| 149 |
+
- `POST /v1/reset/{state_id}`
|
| 150 |
+
- `POST /diagnose/stream`
|
| 151 |
+
- `GET /health`
|
| 152 |
+
|
| 153 |
+
## Which One To Use
|
| 154 |
+
|
| 155 |
+
- Use `full_precision` when you want the original model path and best fidelity.
|
| 156 |
+
- Use `int4_quantized` when GPU memory is tight or when you are on an environment where `flash-attn` is not the practical option.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/.ipynb_checkpoints/deepseek_service-checkpoint.py
DELETED
|
@@ -1,384 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
DeepSeek API Service
|
| 3 |
-
Used to optimize and organize SkinGPT model output results
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import re
|
| 8 |
-
from typing import Optional
|
| 9 |
-
from openai import AsyncOpenAI
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class DeepSeekService:
|
| 13 |
-
"""DeepSeek API Service Class"""
|
| 14 |
-
|
| 15 |
-
def __init__(self, api_key: Optional[str] = None):
|
| 16 |
-
"""
|
| 17 |
-
Initialize DeepSeek service
|
| 18 |
-
|
| 19 |
-
Parameters:
|
| 20 |
-
api_key: DeepSeek API key, reads from environment variable if not provided
|
| 21 |
-
"""
|
| 22 |
-
self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
| 23 |
-
self.base_url = "https://api.deepseek.com"
|
| 24 |
-
self.model = "deepseek-chat" # Using deepseek-chat model
|
| 25 |
-
|
| 26 |
-
self.client = None
|
| 27 |
-
self.is_loaded = False
|
| 28 |
-
|
| 29 |
-
print(f"DeepSeek API service initializing...")
|
| 30 |
-
print(f"API Base URL: {self.base_url}")
|
| 31 |
-
|
| 32 |
-
async def load(self):
|
| 33 |
-
"""Initialize DeepSeek API client"""
|
| 34 |
-
try:
|
| 35 |
-
if not self.api_key:
|
| 36 |
-
print("DeepSeek API key not provided")
|
| 37 |
-
self.is_loaded = False
|
| 38 |
-
return
|
| 39 |
-
|
| 40 |
-
# Initialize OpenAI compatible client
|
| 41 |
-
self.client = AsyncOpenAI(
|
| 42 |
-
api_key=self.api_key,
|
| 43 |
-
base_url=self.base_url
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
self.is_loaded = True
|
| 47 |
-
print("DeepSeek API service is ready!")
|
| 48 |
-
|
| 49 |
-
except Exception as e:
|
| 50 |
-
print(f"DeepSeek API service initialization failed: {e}")
|
| 51 |
-
self.is_loaded = False
|
| 52 |
-
|
| 53 |
-
async def refine_diagnosis(
|
| 54 |
-
self,
|
| 55 |
-
raw_answer: str,
|
| 56 |
-
raw_thinking: Optional[str] = None,
|
| 57 |
-
language: str = "zh"
|
| 58 |
-
) -> dict:
|
| 59 |
-
"""
|
| 60 |
-
Use DeepSeek API to optimize and organize diagnosis results
|
| 61 |
-
|
| 62 |
-
Parameters:
|
| 63 |
-
raw_answer: Original diagnosis result
|
| 64 |
-
raw_thinking: AI thinking process
|
| 65 |
-
language: Language option
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
Dictionary containing "description", "analysis_process" and "diagnosis_result"
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
if not self.is_loaded or self.client is None:
|
| 72 |
-
error_msg = "API not initialized, cannot generate analysis" if language == "en" else "API未初始化,无法生成分析过程"
|
| 73 |
-
print("DeepSeek API not initialized, returning original result")
|
| 74 |
-
return {
|
| 75 |
-
"success": False,
|
| 76 |
-
"description": "",
|
| 77 |
-
"analysis_process": raw_thinking or error_msg,
|
| 78 |
-
"diagnosis_result": raw_answer,
|
| 79 |
-
"original_diagnosis": raw_answer,
|
| 80 |
-
"error": "DeepSeek API not initialized"
|
| 81 |
-
}
|
| 82 |
-
|
| 83 |
-
try:
|
| 84 |
-
# Build prompt
|
| 85 |
-
prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
|
| 86 |
-
|
| 87 |
-
# Select system prompt based on language
|
| 88 |
-
if language == "en":
|
| 89 |
-
system_content = "You are a professional medical text editor. Your task is to polish and organize medical diagnostic text to make it flow smoothly while preserving the original meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, or thoughts. Just follow the format exactly."
|
| 90 |
-
else:
|
| 91 |
-
system_content = "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
|
| 92 |
-
|
| 93 |
-
# Call DeepSeek API
|
| 94 |
-
response = await self.client.chat.completions.create(
|
| 95 |
-
model=self.model,
|
| 96 |
-
messages=[
|
| 97 |
-
{"role": "system", "content": system_content},
|
| 98 |
-
{"role": "user", "content": prompt}
|
| 99 |
-
],
|
| 100 |
-
temperature=0.1,
|
| 101 |
-
max_tokens=2048,
|
| 102 |
-
top_p=0.8,
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
# Extract generated text
|
| 106 |
-
generated_text = response.choices[0].message.content
|
| 107 |
-
|
| 108 |
-
# Parse output
|
| 109 |
-
parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
|
| 110 |
-
|
| 111 |
-
return {
|
| 112 |
-
"success": True,
|
| 113 |
-
"description": parsed["description"],
|
| 114 |
-
"analysis_process": parsed["analysis_process"],
|
| 115 |
-
"diagnosis_result": parsed["diagnosis_result"],
|
| 116 |
-
"original_diagnosis": raw_answer,
|
| 117 |
-
"raw_refined": generated_text
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
except Exception as e:
|
| 121 |
-
print(f"DeepSeek API call failed: {e}")
|
| 122 |
-
error_msg = "API call failed, cannot generate analysis" if language == "en" else "API调用失败,无法生成分析过程"
|
| 123 |
-
return {
|
| 124 |
-
"success": False,
|
| 125 |
-
"description": "",
|
| 126 |
-
"analysis_process": raw_thinking or error_msg,
|
| 127 |
-
"diagnosis_result": raw_answer,
|
| 128 |
-
"original_diagnosis": raw_answer,
|
| 129 |
-
"error": str(e)
|
| 130 |
-
}
|
| 131 |
-
|
| 132 |
-
def _build_refine_prompt(self, raw_answer: str, raw_thinking: Optional[str] = None, language: str = "zh") -> str:
|
| 133 |
-
"""
|
| 134 |
-
Build optimization prompt
|
| 135 |
-
|
| 136 |
-
Parameters:
|
| 137 |
-
raw_answer: Original diagnosis result
|
| 138 |
-
raw_thinking: AI thinking process
|
| 139 |
-
language: Language option, "zh" for Chinese, "en" for English
|
| 140 |
-
|
| 141 |
-
Returns:
|
| 142 |
-
Built prompt
|
| 143 |
-
"""
|
| 144 |
-
if language == "en":
|
| 145 |
-
# English prompt - organize and polish while preserving meaning
|
| 146 |
-
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 147 |
-
prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
|
| 148 |
-
|
| 149 |
-
【Requirements】
|
| 150 |
-
- Preserve the original tone and expression style
|
| 151 |
-
- Text 1 contains the thinking process, Text 2 contains the diagnosis result
|
| 152 |
-
- Extract the image observation part from the thinking process as Description. This should include all factual observations about what was seen in the image, not just a brief summary.
|
| 153 |
-
- For Diagnostic Reasoning: refine and condense the remaining thinking content. Remove redundancies, self-doubt, circular reasoning, and unnecessary repetition. Keep it concise and not too long. Keep the logical chain clear and enhance readability. IMPORTANT: DO NOT include any image description or visual observations in Diagnostic Reasoning. Only include reasoning, analysis, and diagnostic thought process.
|
| 154 |
-
- If [Text 1] content is NOT: No analysis process available. Then organize [Text 1] content accordingly, DO NOT confuse [Text 1] and [Text 2]
|
| 155 |
-
- If [Text 1] content IS: No analysis process available. Then extract the analysis process and description from [Text 2]
|
| 156 |
-
- DO NOT infer or add new medical information, DO NOT output any meta-commentary
|
| 157 |
-
- You may adjust unreasonable statements or remove redundant content to improve clarity
|
| 158 |
-
|
| 159 |
-
[Text 1]
|
| 160 |
-
{thinking_text}
|
| 161 |
-
|
| 162 |
-
[Text 2]
|
| 163 |
-
{raw_answer}
|
| 164 |
-
|
| 165 |
-
【Output】Only output three sections, do not output anything else:
|
| 166 |
-
## Description
|
| 167 |
-
(Extract all image observation content from the thinking process - include all factual descriptions of what was seen)
|
| 168 |
-
|
| 169 |
-
## Analysis Process
|
| 170 |
-
(Refined and condensed diagnostic reasoning: remove self-doubt, circular logic, and redundancies. Keep it concise and not too long. Keep logical flow clear. Do NOT include image observations)
|
| 171 |
-
|
| 172 |
-
## Diagnosis Result
|
| 173 |
-
(The organized diagnosis result from Text 2)
|
| 174 |
-
|
| 175 |
-
【Example】:
|
| 176 |
-
## Description
|
| 177 |
-
The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
|
| 178 |
-
|
| 179 |
-
## Analysis Process
|
| 180 |
-
These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
|
| 181 |
-
|
| 182 |
-
## Diagnosis Result
|
| 183 |
-
Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
|
| 184 |
-
"""
|
| 185 |
-
else:
|
| 186 |
-
# Chinese prompt - translate to Simplified Chinese AND organize/polish
|
| 187 |
-
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 188 |
-
prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
|
| 189 |
-
|
| 190 |
-
【要求】
|
| 191 |
-
- 保留原文的语气和表达方式
|
| 192 |
-
- 文本1是思考过程,文本2是诊断结果
|
| 193 |
-
- 从思考过程中提取图像观察部分作为图像描述。需要包含所有关于图片中观察到的事实内容,不要简化或缩短。
|
| 194 |
-
- 对于分析过程:提炼并精简剩余的思考内容,去除冗余、自我怀疑、兜圈子的内容。保持简洁,不要太长。保持逻辑链条清晰,增强可读性。重要:分析过程中不���包含任何图像描述或视觉观察内容,只包含推理、分析和诊断思考过程。
|
| 195 |
-
- 如果【文本1】内容不是:No analysis process available.那么按要求整理【文本1】的内容,不要混淆【文本1】和【文本2】。
|
| 196 |
-
- 如果【文本1】内容是:No analysis process available.那么从【文本2】提炼分析过程和描述。
|
| 197 |
-
- 【文本1】和【文本2】需要翻译成简体中文
|
| 198 |
-
- 禁止推断或添加新的医学信息,禁止输出任何元评论
|
| 199 |
-
- 可以调整不合理的语句或去除冗余内容以提高清晰度
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
【文本1】
|
| 203 |
-
{thinking_text}
|
| 204 |
-
|
| 205 |
-
【文本2】
|
| 206 |
-
{raw_answer}
|
| 207 |
-
|
| 208 |
-
【输出】只输出三个部分,不要输出其他任何内容:
|
| 209 |
-
## 图像描述
|
| 210 |
-
(从思考过程中提取所有图像观察内容,包含所有关于图片的事实描述)
|
| 211 |
-
|
| 212 |
-
## 分析过程
|
| 213 |
-
(提炼并精简后的诊断推理:去除自我怀疑、兜圈逻辑和冗余内容。保持简洁,不要太长。保持逻辑流畅。不包含图像观察)
|
| 214 |
-
|
| 215 |
-
## 诊断结果
|
| 216 |
-
(整理后的诊断结果)
|
| 217 |
-
|
| 218 |
-
【样例】:
|
| 219 |
-
## 图像描述
|
| 220 |
-
图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
|
| 221 |
-
|
| 222 |
-
## 分析过程
|
| 223 |
-
这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
|
| 224 |
-
|
| 225 |
-
## 诊断结果
|
| 226 |
-
可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
|
| 227 |
-
"""
|
| 228 |
-
|
| 229 |
-
return prompt
|
| 230 |
-
|
| 231 |
-
def _parse_refined_output(
|
| 232 |
-
self,
|
| 233 |
-
generated_text: str,
|
| 234 |
-
raw_answer: str,
|
| 235 |
-
raw_thinking: Optional[str] = None,
|
| 236 |
-
language: str = "zh"
|
| 237 |
-
) -> dict:
|
| 238 |
-
"""
|
| 239 |
-
Parse DeepSeek generated output
|
| 240 |
-
|
| 241 |
-
Parameters:
|
| 242 |
-
generated_text: DeepSeek generated text
|
| 243 |
-
raw_answer: Original diagnosis (as fallback)
|
| 244 |
-
raw_thinking: Original thinking process (as fallback)
|
| 245 |
-
language: Language option
|
| 246 |
-
|
| 247 |
-
Returns:
|
| 248 |
-
Dictionary containing description, analysis_process and diagnosis_result
|
| 249 |
-
"""
|
| 250 |
-
description = ""
|
| 251 |
-
analysis_process = None
|
| 252 |
-
diagnosis_result = None
|
| 253 |
-
|
| 254 |
-
if language == "en":
|
| 255 |
-
# English patterns
|
| 256 |
-
desc_match = re.search(
|
| 257 |
-
r'##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)',
|
| 258 |
-
generated_text,
|
| 259 |
-
re.IGNORECASE
|
| 260 |
-
)
|
| 261 |
-
analysis_match = re.search(
|
| 262 |
-
r'##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)',
|
| 263 |
-
generated_text,
|
| 264 |
-
re.IGNORECASE
|
| 265 |
-
)
|
| 266 |
-
result_match = re.search(
|
| 267 |
-
r'##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$',
|
| 268 |
-
generated_text,
|
| 269 |
-
re.IGNORECASE
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
desc_header = "## Description"
|
| 273 |
-
analysis_header = "## Analysis Process"
|
| 274 |
-
result_header = "## Diagnosis Result"
|
| 275 |
-
else:
|
| 276 |
-
# Chinese patterns
|
| 277 |
-
desc_match = re.search(
|
| 278 |
-
r'##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)',
|
| 279 |
-
generated_text
|
| 280 |
-
)
|
| 281 |
-
analysis_match = re.search(
|
| 282 |
-
r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
|
| 283 |
-
generated_text
|
| 284 |
-
)
|
| 285 |
-
result_match = re.search(
|
| 286 |
-
r'##\s*诊断结果\s*\n([\s\S]*?)$',
|
| 287 |
-
generated_text
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
desc_header = "## 图像描述"
|
| 291 |
-
analysis_header = "## 分析过程"
|
| 292 |
-
result_header = "## 诊断结果"
|
| 293 |
-
|
| 294 |
-
# Extract description
|
| 295 |
-
if desc_match:
|
| 296 |
-
description = desc_match.group(1).strip()
|
| 297 |
-
print(f"Successfully parsed description")
|
| 298 |
-
else:
|
| 299 |
-
print(f"Description parsing failed")
|
| 300 |
-
description = ""
|
| 301 |
-
|
| 302 |
-
# Extract analysis process
|
| 303 |
-
if analysis_match:
|
| 304 |
-
analysis_process = analysis_match.group(1).strip()
|
| 305 |
-
print(f"Successfully parsed analysis process")
|
| 306 |
-
else:
|
| 307 |
-
print(f"Analysis process parsing failed, trying other methods")
|
| 308 |
-
# Try to extract from generated text
|
| 309 |
-
result_pos = generated_text.find(result_header)
|
| 310 |
-
if result_pos > 0:
|
| 311 |
-
# Get content before diagnosis result
|
| 312 |
-
analysis_process = generated_text[:result_pos].strip()
|
| 313 |
-
# Remove possible headers
|
| 314 |
-
for header in [desc_header, analysis_header]:
|
| 315 |
-
header_escaped = re.escape(header)
|
| 316 |
-
analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
|
| 317 |
-
else:
|
| 318 |
-
# If no format at all, try to get first half
|
| 319 |
-
mid_point = len(generated_text) // 2
|
| 320 |
-
analysis_process = generated_text[:mid_point].strip()
|
| 321 |
-
|
| 322 |
-
# If still empty, use original content (final fallback)
|
| 323 |
-
if not analysis_process and raw_thinking:
|
| 324 |
-
print(f"Using original raw_thinking as fallback")
|
| 325 |
-
analysis_process = raw_thinking
|
| 326 |
-
|
| 327 |
-
# Extract diagnosis result
|
| 328 |
-
if result_match:
|
| 329 |
-
diagnosis_result = result_match.group(1).strip()
|
| 330 |
-
print(f"Successfully parsed diagnosis result")
|
| 331 |
-
else:
|
| 332 |
-
print(f"Diagnosis result parsing failed, trying other methods")
|
| 333 |
-
# Try to extract from generated text
|
| 334 |
-
result_pos = generated_text.find(result_header)
|
| 335 |
-
if result_pos > 0:
|
| 336 |
-
diagnosis_result = generated_text[result_pos:].strip()
|
| 337 |
-
# Remove possible header
|
| 338 |
-
result_header_escaped = re.escape(result_header)
|
| 339 |
-
diagnosis_result = re.sub(f'^{result_header_escaped}\\s*\\n?', '', diagnosis_result).strip()
|
| 340 |
-
else:
|
| 341 |
-
# If no format at all, get second half
|
| 342 |
-
mid_point = len(generated_text) // 2
|
| 343 |
-
diagnosis_result = generated_text[mid_point:].strip()
|
| 344 |
-
|
| 345 |
-
# If still empty, use original content (final fallback)
|
| 346 |
-
if not diagnosis_result:
|
| 347 |
-
print(f"Using original raw_answer as fallback")
|
| 348 |
-
diagnosis_result = raw_answer
|
| 349 |
-
|
| 350 |
-
return {
|
| 351 |
-
"description": description,
|
| 352 |
-
"analysis_process": analysis_process,
|
| 353 |
-
"diagnosis_result": diagnosis_result
|
| 354 |
-
}
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
# Global DeepSeek service instance (lazy loading)
|
| 358 |
-
_deepseek_service: Optional[DeepSeekService] = None
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
|
| 362 |
-
"""
|
| 363 |
-
Get DeepSeek service instance (singleton pattern)
|
| 364 |
-
|
| 365 |
-
Parameters:
|
| 366 |
-
api_key: Optional API key to use
|
| 367 |
-
|
| 368 |
-
Returns:
|
| 369 |
-
DeepSeekService instance, or None if API initialization fails
|
| 370 |
-
"""
|
| 371 |
-
global _deepseek_service
|
| 372 |
-
|
| 373 |
-
if _deepseek_service is None:
|
| 374 |
-
try:
|
| 375 |
-
_deepseek_service = DeepSeekService(api_key=api_key)
|
| 376 |
-
await _deepseek_service.load()
|
| 377 |
-
if not _deepseek_service.is_loaded:
|
| 378 |
-
print("DeepSeek API service initialization failed, will use fallback mode")
|
| 379 |
-
return _deepseek_service # Return instance but marked as not loaded
|
| 380 |
-
except Exception as e:
|
| 381 |
-
print(f"DeepSeek service initialization failed: {e}")
|
| 382 |
-
return None
|
| 383 |
-
|
| 384 |
-
return _deepseek_service
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/.ipynb_checkpoints/demo-checkpoint.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 3 |
-
from qwen_vl_utils import process_vision_info
|
| 4 |
-
from PIL import Image
|
| 5 |
-
|
| 6 |
-
# === Configuration ===
|
| 7 |
-
MODEL_PATH = "../checkpoint"
|
| 8 |
-
IMAGE_PATH = "test_image.jpg" # Please replace with your actual image path
|
| 9 |
-
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
|
| 10 |
-
|
| 11 |
-
def main():
|
| 12 |
-
print(f"Loading model from {MODEL_PATH}...")
|
| 13 |
-
|
| 14 |
-
# 1. Load Model
|
| 15 |
-
try:
|
| 16 |
-
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 17 |
-
MODEL_PATH,
|
| 18 |
-
torch_dtype=torch.bfloat16,
|
| 19 |
-
device_map="auto",
|
| 20 |
-
trust_remote_code=True
|
| 21 |
-
)
|
| 22 |
-
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 23 |
-
except Exception as e:
|
| 24 |
-
print(f"Error loading model: {e}")
|
| 25 |
-
return
|
| 26 |
-
|
| 27 |
-
# 2. Check Image
|
| 28 |
-
import os
|
| 29 |
-
if not os.path.exists(IMAGE_PATH):
|
| 30 |
-
print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
|
| 31 |
-
# Create a dummy image for code demonstration purposes if needed, or just return
|
| 32 |
-
return
|
| 33 |
-
|
| 34 |
-
# 3. Prepare Inputs
|
| 35 |
-
messages = [
|
| 36 |
-
{
|
| 37 |
-
"role": "user",
|
| 38 |
-
"content": [
|
| 39 |
-
{"type": "image", "image": IMAGE_PATH},
|
| 40 |
-
{"type": "text", "text": PROMPT},
|
| 41 |
-
],
|
| 42 |
-
}
|
| 43 |
-
]
|
| 44 |
-
|
| 45 |
-
print("Processing...")
|
| 46 |
-
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 47 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
-
|
| 49 |
-
inputs = processor(
|
| 50 |
-
text=[text],
|
| 51 |
-
images=image_inputs,
|
| 52 |
-
videos=video_inputs,
|
| 53 |
-
padding=True,
|
| 54 |
-
return_tensors="pt",
|
| 55 |
-
).to(model.device)
|
| 56 |
-
|
| 57 |
-
# 4. Generate
|
| 58 |
-
with torch.no_grad():
|
| 59 |
-
generated_ids = model.generate(
|
| 60 |
-
**inputs,
|
| 61 |
-
max_new_tokens=1024,
|
| 62 |
-
temperature=0.7,
|
| 63 |
-
top_p=0.9
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# 5. Decode
|
| 67 |
-
output_text = processor.batch_decode(
|
| 68 |
-
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
print("\n=== Diagnosis Result ===")
|
| 72 |
-
print(output_text[0])
|
| 73 |
-
print("========================")
|
| 74 |
-
|
| 75 |
-
if __name__ == "__main__":
|
| 76 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/.ipynb_checkpoints/inference-checkpoint.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import argparse
|
| 3 |
-
from model_utils import SkinGPTModel
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
def main():
|
| 7 |
-
parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
|
| 8 |
-
parser.add_argument("--image", type=str, required=True, help="Path to the image")
|
| 9 |
-
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 10 |
-
parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
|
| 11 |
-
args = parser.parse_args()
|
| 12 |
-
|
| 13 |
-
if not os.path.exists(args.image):
|
| 14 |
-
print(f"Error: Image not found at {args.image}")
|
| 15 |
-
return
|
| 16 |
-
|
| 17 |
-
# 1. 加载模型 (复用 model_utils)
|
| 18 |
-
# 这样你就不用在这里重复写 transformers 的加载代码了
|
| 19 |
-
bot = SkinGPTModel(args.model_path)
|
| 20 |
-
|
| 21 |
-
# 2. 构造单轮消息
|
| 22 |
-
system_prompt = "You are a professional AI dermatology assistant."
|
| 23 |
-
messages = [
|
| 24 |
-
{
|
| 25 |
-
"role": "user",
|
| 26 |
-
"content": [
|
| 27 |
-
{"type": "image", "image": args.image},
|
| 28 |
-
{"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
|
| 29 |
-
]
|
| 30 |
-
}
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
# 3. 推理
|
| 34 |
-
print(f"\nAnalyzing {args.image}...")
|
| 35 |
-
response = bot.generate_response(messages)
|
| 36 |
-
|
| 37 |
-
print("-" * 40)
|
| 38 |
-
print("Result:")
|
| 39 |
-
print(response)
|
| 40 |
-
print("-" * 40)
|
| 41 |
-
|
| 42 |
-
if __name__ == "__main__":
|
| 43 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/.ipynb_checkpoints/model_utils-checkpoint.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
# model_utils.py
|
| 2 |
-
import torch
|
| 3 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
|
| 4 |
-
from qwen_vl_utils import process_vision_info
|
| 5 |
-
from PIL import Image
|
| 6 |
-
import os
|
| 7 |
-
from threading import Thread
|
| 8 |
-
|
| 9 |
-
class SkinGPTModel:
|
| 10 |
-
def __init__(self, model_path, device=None):
|
| 11 |
-
self.model_path = model_path
|
| 12 |
-
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
-
print(f"Loading model from {model_path} on {self.device}...")
|
| 14 |
-
|
| 15 |
-
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 16 |
-
model_path,
|
| 17 |
-
torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
|
| 18 |
-
attn_implementation="flash_attention_2" if self.device == "cuda" else None,
|
| 19 |
-
device_map="auto" if self.device != "mps" else None,
|
| 20 |
-
trust_remote_code=True
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
if self.device == "mps":
|
| 24 |
-
self.model = self.model.to(self.device)
|
| 25 |
-
|
| 26 |
-
self.processor = AutoProcessor.from_pretrained(
|
| 27 |
-
model_path,
|
| 28 |
-
trust_remote_code=True,
|
| 29 |
-
min_pixels=256*28*28,
|
| 30 |
-
max_pixels=1280*28*28
|
| 31 |
-
)
|
| 32 |
-
print("Model loaded successfully.")
|
| 33 |
-
|
| 34 |
-
def generate_response(self, messages, max_new_tokens=1024, temperature=0.7):
|
| 35 |
-
"""
|
| 36 |
-
处理多轮对话的历史消息列表并生成回复
|
| 37 |
-
messages format:
|
| 38 |
-
[
|
| 39 |
-
{'role': 'user', 'content': [{'type': 'image', 'image': 'path...'}, {'type': 'text', 'text': '...'}]},
|
| 40 |
-
{'role': 'assistant', 'content': [{'type': 'text', 'text': '...'}]}
|
| 41 |
-
]
|
| 42 |
-
"""
|
| 43 |
-
# 预处理文本模板
|
| 44 |
-
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 45 |
-
|
| 46 |
-
# 预处理视觉信息
|
| 47 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
-
|
| 49 |
-
inputs = self.processor(
|
| 50 |
-
text=[text],
|
| 51 |
-
images=image_inputs,
|
| 52 |
-
videos=video_inputs,
|
| 53 |
-
padding=True,
|
| 54 |
-
return_tensors="pt",
|
| 55 |
-
).to(self.model.device)
|
| 56 |
-
|
| 57 |
-
with torch.no_grad():
|
| 58 |
-
generated_ids = self.model.generate(
|
| 59 |
-
**inputs,
|
| 60 |
-
max_new_tokens=max_new_tokens,
|
| 61 |
-
temperature=temperature,
|
| 62 |
-
top_p=0.9,
|
| 63 |
-
do_sample=True
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# 解码输出 (去除输入的token)
|
| 67 |
-
generated_ids_trimmed = [
|
| 68 |
-
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 69 |
-
]
|
| 70 |
-
output_text = self.processor.batch_decode(
|
| 71 |
-
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
return output_text[0]
|
| 75 |
-
|
| 76 |
-
def generate_response_stream(self, messages, max_new_tokens=2048, temperature=0.7):
|
| 77 |
-
"""
|
| 78 |
-
流式生成响应
|
| 79 |
-
返回一个生成器,逐个yield生成的文本chunk
|
| 80 |
-
"""
|
| 81 |
-
# 预处理文本模板
|
| 82 |
-
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 83 |
-
|
| 84 |
-
# 预处理视觉信息
|
| 85 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
| 86 |
-
|
| 87 |
-
inputs = self.processor(
|
| 88 |
-
text=[text],
|
| 89 |
-
images=image_inputs,
|
| 90 |
-
videos=video_inputs,
|
| 91 |
-
padding=True,
|
| 92 |
-
return_tensors="pt",
|
| 93 |
-
).to(self.model.device)
|
| 94 |
-
|
| 95 |
-
# 创建 TextIteratorStreamer 用于流式输出
|
| 96 |
-
streamer = TextIteratorStreamer(
|
| 97 |
-
self.processor.tokenizer,
|
| 98 |
-
skip_prompt=True,
|
| 99 |
-
skip_special_tokens=True
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
# 准备生成参数
|
| 103 |
-
generation_kwargs = {
|
| 104 |
-
**inputs,
|
| 105 |
-
"max_new_tokens": max_new_tokens,
|
| 106 |
-
"temperature": temperature,
|
| 107 |
-
"top_p": 0.9,
|
| 108 |
-
"do_sample": True,
|
| 109 |
-
"streamer": streamer,
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
# 在单独的线程中运行生成
|
| 113 |
-
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
| 114 |
-
thread.start()
|
| 115 |
-
|
| 116 |
-
# 逐个yield生成的文本
|
| 117 |
-
for text_chunk in streamer:
|
| 118 |
-
yield text_chunk
|
| 119 |
-
|
| 120 |
-
thread.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Inference
|
| 2 |
+
|
| 3 |
+
Two runtime tracks are provided:
|
| 4 |
+
|
| 5 |
+
- `full_precision/`: single-image inference, multi-turn chat, and FastAPI service
|
| 6 |
+
- `int4_quantized/`: single-image inference, multi-turn chat, and FastAPI service for the INT4 path
|
| 7 |
+
|
| 8 |
+
Checkpoint paths:
|
| 9 |
+
|
| 10 |
+
- `./checkpoints/full_precision`
|
| 11 |
+
- `./checkpoints/int4`
|
inference/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Inference entrypoints for SkinGPT-R1."""
|
inference/__pycache__/app.cpython-311.pyc
DELETED
|
Binary file (17.8 kB)
|
|
|
inference/__pycache__/deepseek_service.cpython-311.pyc
DELETED
|
Binary file (18.3 kB)
|
|
|
inference/__pycache__/model_utils.cpython-311.pyc
DELETED
|
Binary file (5.39 kB)
|
|
|
inference/demo.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
| 3 |
-
from qwen_vl_utils import process_vision_info
|
| 4 |
-
from PIL import Image
|
| 5 |
-
|
| 6 |
-
# === Configuration ===
|
| 7 |
-
MODEL_PATH = "../checkpoint"
|
| 8 |
-
IMAGE_PATH = "test_image.jpg" # Please replace with your actual image path
|
| 9 |
-
PROMPT = "You are a professional AI dermatology assistant. Please analyze this skin image and provide a diagnosis."
|
| 10 |
-
|
| 11 |
-
def main():
|
| 12 |
-
print(f"Loading model from {MODEL_PATH}...")
|
| 13 |
-
|
| 14 |
-
# 1. Load Model
|
| 15 |
-
try:
|
| 16 |
-
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 17 |
-
MODEL_PATH,
|
| 18 |
-
torch_dtype=torch.bfloat16,
|
| 19 |
-
device_map="auto",
|
| 20 |
-
trust_remote_code=True
|
| 21 |
-
)
|
| 22 |
-
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
| 23 |
-
except Exception as e:
|
| 24 |
-
print(f"Error loading model: {e}")
|
| 25 |
-
return
|
| 26 |
-
|
| 27 |
-
# 2. Check Image
|
| 28 |
-
import os
|
| 29 |
-
if not os.path.exists(IMAGE_PATH):
|
| 30 |
-
print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
|
| 31 |
-
# Create a dummy image for code demonstration purposes if needed, or just return
|
| 32 |
-
return
|
| 33 |
-
|
| 34 |
-
# 3. Prepare Inputs
|
| 35 |
-
messages = [
|
| 36 |
-
{
|
| 37 |
-
"role": "user",
|
| 38 |
-
"content": [
|
| 39 |
-
{"type": "image", "image": IMAGE_PATH},
|
| 40 |
-
{"type": "text", "text": PROMPT},
|
| 41 |
-
],
|
| 42 |
-
}
|
| 43 |
-
]
|
| 44 |
-
|
| 45 |
-
print("Processing...")
|
| 46 |
-
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 47 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
-
|
| 49 |
-
inputs = processor(
|
| 50 |
-
text=[text],
|
| 51 |
-
images=image_inputs,
|
| 52 |
-
videos=video_inputs,
|
| 53 |
-
padding=True,
|
| 54 |
-
return_tensors="pt",
|
| 55 |
-
).to(model.device)
|
| 56 |
-
|
| 57 |
-
# 4. Generate
|
| 58 |
-
with torch.no_grad():
|
| 59 |
-
generated_ids = model.generate(
|
| 60 |
-
**inputs,
|
| 61 |
-
max_new_tokens=1024,
|
| 62 |
-
temperature=0.7,
|
| 63 |
-
repetition_penalty=1.2,
|
| 64 |
-
no_repeat_ngram_size=3,
|
| 65 |
-
top_p=0.9,
|
| 66 |
-
do_sample=True
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
# 5. Decode
|
| 70 |
-
output_text = processor.batch_decode(
|
| 71 |
-
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
print("\n=== Diagnosis Result ===")
|
| 75 |
-
print(output_text[0])
|
| 76 |
-
print("========================")
|
| 77 |
-
|
| 78 |
-
if __name__ == "__main__":
|
| 79 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/full_precision/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Full-precision inference package for SkinGPT-R1."""
|
inference/full_precision/__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
inference/full_precision/__pycache__/chat.cpython-311.pyc
ADDED
|
Binary file (3.58 kB). View file
|
|
|
inference/full_precision/__pycache__/deepseek_service.cpython-311.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
inference/full_precision/__pycache__/demo.cpython-311.pyc
ADDED
|
Binary file (1.8 kB). View file
|
|
|
inference/full_precision/__pycache__/infer.cpython-311.pyc
ADDED
|
Binary file (2.63 kB). View file
|
|
|
inference/full_precision/__pycache__/model_utils.cpython-311.pyc
ADDED
|
Binary file (7.37 kB). View file
|
|
|
inference/{app.py → full_precision/app.py}
RENAMED
|
@@ -1,133 +1,81 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
import uuid
|
| 6 |
-
import json
|
| 7 |
-
import re
|
| 8 |
-
import asyncio
|
| 9 |
-
from typing import Optional
|
| 10 |
-
from io import BytesIO
|
| 11 |
from contextlib import asynccontextmanager
|
| 12 |
-
from
|
| 13 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
from fastapi.responses import StreamingResponse
|
| 16 |
-
from
|
| 17 |
-
from model_utils import SkinGPTModel
|
| 18 |
-
from deepseek_service import get_deepseek_service, DeepSeekService
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
# Global DeepSeek service instance
|
| 29 |
deepseek_service: Optional[DeepSeekService] = None
|
| 30 |
|
| 31 |
-
@asynccontextmanager
|
| 32 |
-
async def lifespan(app: FastAPI):
|
| 33 |
-
"""应用生命周期管理"""
|
| 34 |
-
# 启动时初始化 DeepSeek 服务
|
| 35 |
-
await init_deepseek()
|
| 36 |
-
yield
|
| 37 |
-
print("\nShutting down service...")
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
description="智能皮肤诊断助手",
|
| 42 |
-
version="1.0.0",
|
| 43 |
-
lifespan=lifespan
|
| 44 |
-
)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
CORSMiddleware,
|
| 49 |
-
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 50 |
-
allow_credentials=True,
|
| 51 |
-
allow_methods=["*"],
|
| 52 |
-
allow_headers=["*"],
|
| 53 |
-
)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# pending_images: 存储已上传但尚未发送给LLM的图片路径 (State ID -> Image Path)
|
| 58 |
-
chat_states = {}
|
| 59 |
-
pending_images = {}
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
解析诊断结果中的think和answer标签
|
| 64 |
-
|
| 65 |
-
参数:
|
| 66 |
-
- raw_text: 原始诊断文本
|
| 67 |
-
|
| 68 |
-
返回:
|
| 69 |
-
- dict: 包含thinking, answer, raw字段的字典
|
| 70 |
-
"""
|
| 71 |
-
import re
|
| 72 |
-
|
| 73 |
-
# 尝试匹配完整的标签
|
| 74 |
-
think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
|
| 75 |
-
answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
|
| 76 |
-
|
| 77 |
-
thinking = None
|
| 78 |
-
answer = None
|
| 79 |
-
|
| 80 |
-
# 处理think标签
|
| 81 |
-
if think_match:
|
| 82 |
-
thinking = think_match.group(1).strip()
|
| 83 |
-
else:
|
| 84 |
-
# 尝试匹配未闭合的think标签(输出被截断的情况)
|
| 85 |
-
unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
|
| 86 |
if unclosed_think:
|
| 87 |
thinking = unclosed_think.group(1).strip()
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
answer = answer_match.group(1).strip()
|
| 92 |
-
else:
|
| 93 |
-
# 尝试匹配未闭合的answer标签
|
| 94 |
-
unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
|
| 95 |
if unclosed_answer:
|
| 96 |
answer = unclosed_answer.group(1).strip()
|
| 97 |
-
|
| 98 |
-
# 如果仍然没有找到answer,清理原始文本作为answer
|
| 99 |
if not answer:
|
| 100 |
-
|
| 101 |
-
cleaned = re.sub(r
|
| 102 |
-
cleaned = re.sub(r
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
answer = cleaned if cleaned else raw_text
|
| 106 |
-
|
| 107 |
-
# 清理可能残留的标签
|
| 108 |
-
if answer:
|
| 109 |
-
answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
|
| 110 |
-
if thinking:
|
| 111 |
-
thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
|
| 112 |
-
|
| 113 |
-
# 处理 "Final Answer:" 格式,提取其后的内容
|
| 114 |
if answer:
|
| 115 |
-
|
|
|
|
| 116 |
if final_answer_match:
|
| 117 |
answer = final_answer_match.group(1).strip()
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
| 125 |
print("Initializing Model Service...")
|
| 126 |
-
# 全局加载模型
|
| 127 |
gpt_model = SkinGPTModel(MODEL_PATH)
|
| 128 |
print("Service Ready.")
|
| 129 |
|
| 130 |
-
|
| 131 |
async def init_deepseek():
|
| 132 |
global deepseek_service
|
| 133 |
print("\nInitializing DeepSeek service...")
|
|
@@ -137,120 +85,116 @@ async def init_deepseek():
|
|
| 137 |
else:
|
| 138 |
print("DeepSeek service not available, will return raw results")
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
@app.post("/v1/upload/{state_id}")
|
| 141 |
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 142 |
-
|
| 143 |
-
接收图片上传。
|
| 144 |
-
逻辑:将图片保存到本地临时目录,并标记该 state_id 有一张待处理图片。
|
| 145 |
-
"""
|
| 146 |
try:
|
| 147 |
-
# 1. 保存图片到本地临时文件
|
| 148 |
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 149 |
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 150 |
-
file_path =
|
| 151 |
-
|
| 152 |
-
with open(
|
| 153 |
shutil.copyfileobj(file.file, buffer)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
pending_images[state_id] = file_path
|
| 158 |
-
|
| 159 |
-
# 3. 初始化对话状态(如果是新会话)
|
| 160 |
if state_id not in chat_states:
|
| 161 |
chat_states[state_id] = []
|
| 162 |
-
|
| 163 |
-
return {"message": "Image uploaded successfully", "path": file_path}
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
|
| 168 |
@app.post("/v1/predict/{state_id}")
|
| 169 |
async def v1_predict(request: Request, state_id: str):
|
| 170 |
-
"""
|
| 171 |
-
接收文本并执行推理。
|
| 172 |
-
逻辑:检查是否有待处理图片。如果有,将其与文本组合成 multimodal 消息。
|
| 173 |
-
"""
|
| 174 |
try:
|
| 175 |
data = await request.json()
|
| 176 |
-
except:
|
| 177 |
-
raise HTTPException(status_code=400, detail="Invalid JSON")
|
| 178 |
-
|
| 179 |
user_message = data.get("message", "")
|
| 180 |
if not user_message:
|
| 181 |
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 182 |
|
| 183 |
-
# 获取或初始化历史
|
| 184 |
history = chat_states.get(state_id, [])
|
| 185 |
-
|
| 186 |
-
# 构建当前轮次的用户内容
|
| 187 |
current_content = []
|
| 188 |
-
|
| 189 |
-
# 1. 检查是否有刚刚上传的图片
|
| 190 |
if state_id in pending_images:
|
| 191 |
-
img_path = pending_images.pop(state_id)
|
| 192 |
current_content.append({"type": "image", "image": img_path})
|
| 193 |
-
|
| 194 |
-
# 如果是第一次对话,加上 System Prompt
|
| 195 |
if not history:
|
| 196 |
-
|
| 197 |
-
user_message = f"{system_prompt}\n\n{user_message}"
|
| 198 |
|
| 199 |
-
# 2. 添加文本
|
| 200 |
current_content.append({"type": "text", "text": user_message})
|
| 201 |
-
|
| 202 |
-
# 3. 更新历史
|
| 203 |
history.append({"role": "user", "content": current_content})
|
| 204 |
chat_states[state_id] = history
|
| 205 |
|
| 206 |
-
# 4. 运行推理 (在线程池中运行以防阻塞)
|
| 207 |
try:
|
| 208 |
-
response_text = await run_in_threadpool(
|
| 209 |
-
|
| 210 |
-
messages=history
|
| 211 |
-
)
|
| 212 |
-
except Exception as e:
|
| 213 |
-
# 回滚历史(移除刚才出错的用户提问)
|
| 214 |
chat_states[state_id].pop()
|
| 215 |
-
raise HTTPException(status_code=500, detail=f"Inference error: {
|
| 216 |
|
| 217 |
-
# 5. 将回复加入历史
|
| 218 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 219 |
chat_states[state_id] = history
|
| 220 |
-
|
| 221 |
return {"message": response_text}
|
| 222 |
|
|
|
|
| 223 |
@app.post("/v1/reset/{state_id}")
|
| 224 |
async def reset_chat(state_id: str):
|
| 225 |
-
"""清除会话状态"""
|
| 226 |
if state_id in chat_states:
|
| 227 |
del chat_states[state_id]
|
| 228 |
if state_id in pending_images:
|
| 229 |
-
# 可选:删除临时文件
|
| 230 |
try:
|
| 231 |
-
|
| 232 |
-
except:
|
| 233 |
pass
|
| 234 |
del pending_images[state_id]
|
| 235 |
return {"message": "Chat history reset"}
|
| 236 |
|
|
|
|
| 237 |
@app.get("/")
|
| 238 |
async def root():
|
| 239 |
-
"""根路径"""
|
| 240 |
return {
|
| 241 |
-
"name": "SkinGPT-R1
|
| 242 |
-
"version": "1.
|
| 243 |
"status": "running",
|
| 244 |
-
"description": "
|
| 245 |
}
|
| 246 |
|
|
|
|
| 247 |
@app.get("/health")
|
| 248 |
async def health_check():
|
| 249 |
-
"""
|
| 250 |
-
|
| 251 |
-
"status": "healthy",
|
| 252 |
-
"model_loaded": True
|
| 253 |
-
}
|
| 254 |
|
| 255 |
@app.post("/diagnose/stream")
|
| 256 |
async def diagnose_stream(
|
|
@@ -258,126 +202,89 @@ async def diagnose_stream(
|
|
| 258 |
text: str = Form(...),
|
| 259 |
language: str = Form("zh"),
|
| 260 |
):
|
| 261 |
-
"""
|
| 262 |
-
SSE流式诊断接口(用于前端)
|
| 263 |
-
支持图片上传和文本输入,返回真正的流式响应
|
| 264 |
-
使用 DeepSeek API 优化输出格式
|
| 265 |
-
"""
|
| 266 |
-
from queue import Queue, Empty
|
| 267 |
-
from threading import Thread
|
| 268 |
-
|
| 269 |
language = language if language in ("zh", "en") else "zh"
|
| 270 |
-
|
| 271 |
-
# 处理图片
|
| 272 |
pil_image = None
|
| 273 |
-
|
| 274 |
-
|
| 275 |
if image:
|
| 276 |
contents = await image.read()
|
| 277 |
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 278 |
-
|
| 279 |
-
# 创建队列用于线程间通信
|
| 280 |
result_queue = Queue()
|
| 281 |
-
# 用于存储完整响应和解析结果
|
| 282 |
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 283 |
-
|
| 284 |
def run_generation():
|
| 285 |
-
"""在后台线程中运行流式生成"""
|
| 286 |
full_response = []
|
| 287 |
-
|
| 288 |
try:
|
| 289 |
-
# 构建消息
|
| 290 |
messages = []
|
| 291 |
current_content = []
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
| 297 |
if pil_image:
|
| 298 |
-
|
| 299 |
-
pil_image.save(
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
current_content.append({"type": "text", "text": prompt})
|
| 305 |
messages.append({"role": "user", "content": current_content})
|
| 306 |
-
|
| 307 |
-
# 流式生成 - 每个 chunk 立即放入队列
|
| 308 |
for chunk in gpt_model.generate_response_stream(
|
| 309 |
messages=messages,
|
| 310 |
max_new_tokens=2048,
|
| 311 |
-
temperature=0.7
|
| 312 |
):
|
| 313 |
full_response.append(chunk)
|
| 314 |
result_queue.put(("delta", chunk))
|
| 315 |
-
|
| 316 |
-
# 解析结果
|
| 317 |
response_text = "".join(full_response)
|
| 318 |
-
parsed = parse_diagnosis_result(response_text)
|
| 319 |
generation_result["full_response"] = full_response
|
| 320 |
-
generation_result["parsed"] =
|
| 321 |
-
|
| 322 |
-
# 标记生成完成
|
| 323 |
result_queue.put(("generation_done", None))
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
async def event_generator():
|
| 329 |
-
"""异步生成SSE事件"""
|
| 330 |
-
# 在后台线程启动生成(非阻塞)
|
| 331 |
gen_thread = Thread(target=run_generation)
|
| 332 |
gen_thread.start()
|
| 333 |
-
|
| 334 |
loop = asyncio.get_event_loop()
|
| 335 |
-
|
| 336 |
-
# 从队列中读取并发送流式内容
|
| 337 |
while True:
|
| 338 |
try:
|
| 339 |
-
# 非阻塞获取
|
| 340 |
msg_type, data = await loop.run_in_executor(
|
| 341 |
-
None,
|
| 342 |
-
lambda: result_queue.get(timeout=0.1)
|
| 343 |
)
|
| 344 |
-
|
| 345 |
if msg_type == "generation_done":
|
| 346 |
-
# 流式生成完成,准备处理最终结果
|
| 347 |
break
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
yield f"data: {yield_chunk}\n\n"
|
| 351 |
elif msg_type == "error":
|
| 352 |
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 353 |
gen_thread.join()
|
| 354 |
return
|
| 355 |
-
|
| 356 |
except Empty:
|
| 357 |
-
# 队列暂时为空,继续等待
|
| 358 |
await asyncio.sleep(0.01)
|
| 359 |
-
|
| 360 |
-
|
| 361 |
gen_thread.join()
|
| 362 |
-
|
| 363 |
-
# 获取解析结果
|
| 364 |
parsed = generation_result["parsed"]
|
| 365 |
if not parsed:
|
| 366 |
-
yield
|
| 367 |
return
|
| 368 |
-
|
| 369 |
raw_thinking = parsed["thinking"]
|
| 370 |
raw_answer = parsed["answer"]
|
| 371 |
-
|
| 372 |
-
# 使用 DeepSeek 优化结果
|
| 373 |
refined_by_deepseek = False
|
| 374 |
description = None
|
| 375 |
thinking = raw_thinking
|
| 376 |
answer = raw_answer
|
| 377 |
-
|
| 378 |
if deepseek_service and deepseek_service.is_loaded:
|
| 379 |
try:
|
| 380 |
-
print(f"Calling DeepSeek to refine diagnosis (language={language})...")
|
| 381 |
refined = await deepseek_service.refine_diagnosis(
|
| 382 |
raw_answer=raw_answer,
|
| 383 |
raw_thinking=raw_thinking,
|
|
@@ -388,36 +295,35 @@ async def diagnose_stream(
|
|
| 388 |
thinking = refined["analysis_process"]
|
| 389 |
answer = refined["diagnosis_result"]
|
| 390 |
refined_by_deepseek = True
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
print(f"DeepSeek refinement failed, using original: {e}")
|
| 394 |
else:
|
| 395 |
print("DeepSeek service not available, using raw results")
|
| 396 |
-
|
| 397 |
-
success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
|
| 398 |
-
|
| 399 |
-
# 返回格式与参考项目保持一致
|
| 400 |
final_payload = {
|
| 401 |
-
"description": description,
|
| 402 |
-
"thinking": thinking,
|
| 403 |
-
"answer": answer,
|
| 404 |
-
"raw": parsed["raw"],
|
| 405 |
-
"refined_by_deepseek": refined_by_deepseek,
|
| 406 |
"success": True,
|
| 407 |
-
"message":
|
| 408 |
}
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
# 清理临时图片
|
| 413 |
temp_path = generation_result.get("temp_image_path")
|
| 414 |
-
if temp_path
|
| 415 |
try:
|
| 416 |
-
|
| 417 |
-
except:
|
| 418 |
pass
|
| 419 |
-
|
| 420 |
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
import os
|
| 6 |
import shutil
|
| 7 |
import uuid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from queue import Empty, Queue
|
| 12 |
+
from threading import Thread
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import uvicorn
|
| 16 |
+
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
| 17 |
+
from fastapi.concurrency import run_in_threadpool
|
| 18 |
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
from fastapi.responses import StreamingResponse
|
| 20 |
+
from PIL import Image
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
try:
|
| 23 |
+
from .deepseek_service import DeepSeekService, get_deepseek_service
|
| 24 |
+
from .model_utils import DEFAULT_MODEL_PATH, SkinGPTModel, resolve_model_path
|
| 25 |
+
except ImportError:
|
| 26 |
+
from deepseek_service import DeepSeekService, get_deepseek_service
|
| 27 |
+
from model_utils import DEFAULT_MODEL_PATH, SkinGPTModel, resolve_model_path
|
| 28 |
|
| 29 |
+
MODEL_PATH = resolve_model_path(DEFAULT_MODEL_PATH)
|
| 30 |
+
TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
|
| 31 |
+
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
| 32 |
+
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
|
| 33 |
|
|
|
|
| 34 |
deepseek_service: Optional[DeepSeekService] = None
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
def parse_diagnosis_result(raw_text: str) -> dict:
|
| 38 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
think_match = re.search(r"<think>([\s\S]*?)</think>", raw_text)
|
| 41 |
+
answer_match = re.search(r"<answer>([\s\S]*?)</answer>", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
thinking = think_match.group(1).strip() if think_match else None
|
| 44 |
+
answer = answer_match.group(1).strip() if answer_match else None
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
if not thinking:
|
| 47 |
+
unclosed_think = re.search(r"<think>([\s\S]*?)(?=<answer>|$)", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
if unclosed_think:
|
| 49 |
thinking = unclosed_think.group(1).strip()
|
| 50 |
+
|
| 51 |
+
if not answer:
|
| 52 |
+
unclosed_answer = re.search(r"<answer>([\s\S]*?)$", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
if unclosed_answer:
|
| 54 |
answer = unclosed_answer.group(1).strip()
|
| 55 |
+
|
|
|
|
| 56 |
if not answer:
|
| 57 |
+
cleaned = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
|
| 58 |
+
cleaned = re.sub(r"<think>[\s\S]*", "", cleaned)
|
| 59 |
+
cleaned = re.sub(r"</?answer>", "", cleaned)
|
| 60 |
+
answer = cleaned.strip() or raw_text
|
| 61 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
if answer:
|
| 63 |
+
answer = re.sub(r"</?think>|</?answer>", "", answer).strip()
|
| 64 |
+
final_answer_match = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
|
| 65 |
if final_answer_match:
|
| 66 |
answer = final_answer_match.group(1).strip()
|
| 67 |
+
|
| 68 |
+
if thinking:
|
| 69 |
+
thinking = re.sub(r"</?think>|</?answer>", "", thinking).strip()
|
| 70 |
+
|
| 71 |
+
return {"thinking": thinking or None, "answer": answer, "raw": raw_text}
|
| 72 |
+
|
| 73 |
|
| 74 |
print("Initializing Model Service...")
|
|
|
|
| 75 |
gpt_model = SkinGPTModel(MODEL_PATH)
|
| 76 |
print("Service Ready.")
|
| 77 |
|
| 78 |
+
|
| 79 |
async def init_deepseek():
|
| 80 |
global deepseek_service
|
| 81 |
print("\nInitializing DeepSeek service...")
|
|
|
|
| 85 |
else:
|
| 86 |
print("DeepSeek service not available, will return raw results")
|
| 87 |
|
| 88 |
+
|
| 89 |
+
@asynccontextmanager
|
| 90 |
+
async def lifespan(app: FastAPI):
|
| 91 |
+
await init_deepseek()
|
| 92 |
+
yield
|
| 93 |
+
print("\nShutting down service...")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
app = FastAPI(
|
| 97 |
+
title="SkinGPT-R1 Full Precision API",
|
| 98 |
+
description="Full-precision dermatology assistant backend",
|
| 99 |
+
version="1.1.0",
|
| 100 |
+
lifespan=lifespan,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
app.add_middleware(
|
| 104 |
+
CORSMiddleware,
|
| 105 |
+
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 106 |
+
allow_credentials=True,
|
| 107 |
+
allow_methods=["*"],
|
| 108 |
+
allow_headers=["*"],
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
chat_states = {}
|
| 112 |
+
pending_images = {}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
@app.post("/v1/upload/{state_id}")
|
| 116 |
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 117 |
+
del survey
|
|
|
|
|
|
|
|
|
|
| 118 |
try:
|
|
|
|
| 119 |
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 120 |
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 121 |
+
file_path = TEMP_DIR / unique_name
|
| 122 |
+
|
| 123 |
+
with file_path.open("wb") as buffer:
|
| 124 |
shutil.copyfileobj(file.file, buffer)
|
| 125 |
+
|
| 126 |
+
pending_images[state_id] = str(file_path)
|
| 127 |
+
|
|
|
|
|
|
|
|
|
|
| 128 |
if state_id not in chat_states:
|
| 129 |
chat_states[state_id] = []
|
| 130 |
+
|
| 131 |
+
return {"message": "Image uploaded successfully", "path": str(file_path)}
|
| 132 |
+
except Exception as exc:
|
| 133 |
+
raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc
|
| 134 |
+
|
| 135 |
|
| 136 |
@app.post("/v1/predict/{state_id}")
|
| 137 |
async def v1_predict(request: Request, state_id: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
try:
|
| 139 |
data = await request.json()
|
| 140 |
+
except Exception as exc:
|
| 141 |
+
raise HTTPException(status_code=400, detail="Invalid JSON") from exc
|
| 142 |
+
|
| 143 |
user_message = data.get("message", "")
|
| 144 |
if not user_message:
|
| 145 |
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 146 |
|
|
|
|
| 147 |
history = chat_states.get(state_id, [])
|
|
|
|
|
|
|
| 148 |
current_content = []
|
| 149 |
+
|
|
|
|
| 150 |
if state_id in pending_images:
|
| 151 |
+
img_path = pending_images.pop(state_id)
|
| 152 |
current_content.append({"type": "image", "image": img_path})
|
|
|
|
|
|
|
| 153 |
if not history:
|
| 154 |
+
user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"
|
|
|
|
| 155 |
|
|
|
|
| 156 |
current_content.append({"type": "text", "text": user_message})
|
|
|
|
|
|
|
| 157 |
history.append({"role": "user", "content": current_content})
|
| 158 |
chat_states[state_id] = history
|
| 159 |
|
|
|
|
| 160 |
try:
|
| 161 |
+
response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
|
| 162 |
+
except Exception as exc:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
chat_states[state_id].pop()
|
| 164 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
|
| 165 |
|
|
|
|
| 166 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 167 |
chat_states[state_id] = history
|
|
|
|
| 168 |
return {"message": response_text}
|
| 169 |
|
| 170 |
+
|
| 171 |
@app.post("/v1/reset/{state_id}")
|
| 172 |
async def reset_chat(state_id: str):
|
|
|
|
| 173 |
if state_id in chat_states:
|
| 174 |
del chat_states[state_id]
|
| 175 |
if state_id in pending_images:
|
|
|
|
| 176 |
try:
|
| 177 |
+
Path(pending_images[state_id]).unlink(missing_ok=True)
|
| 178 |
+
except Exception:
|
| 179 |
pass
|
| 180 |
del pending_images[state_id]
|
| 181 |
return {"message": "Chat history reset"}
|
| 182 |
|
| 183 |
+
|
| 184 |
@app.get("/")
|
| 185 |
async def root():
|
|
|
|
| 186 |
return {
|
| 187 |
+
"name": "SkinGPT-R1 Full Precision API",
|
| 188 |
+
"version": "1.1.0",
|
| 189 |
"status": "running",
|
| 190 |
+
"description": "Full-precision dermatology assistant",
|
| 191 |
}
|
| 192 |
|
| 193 |
+
|
| 194 |
@app.get("/health")
|
| 195 |
async def health_check():
|
| 196 |
+
return {"status": "healthy", "model_loaded": True}
|
| 197 |
+
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
@app.post("/diagnose/stream")
|
| 200 |
async def diagnose_stream(
|
|
|
|
| 202 |
text: str = Form(...),
|
| 203 |
language: str = Form("zh"),
|
| 204 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
language = language if language in ("zh", "en") else "zh"
|
|
|
|
|
|
|
| 206 |
pil_image = None
|
| 207 |
+
|
|
|
|
| 208 |
if image:
|
| 209 |
contents = await image.read()
|
| 210 |
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 211 |
+
|
|
|
|
| 212 |
result_queue = Queue()
|
|
|
|
| 213 |
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 214 |
+
|
| 215 |
def run_generation():
|
|
|
|
| 216 |
full_response = []
|
|
|
|
| 217 |
try:
|
|
|
|
| 218 |
messages = []
|
| 219 |
current_content = []
|
| 220 |
+
system_prompt = (
|
| 221 |
+
"You are a professional AI dermatology assistant."
|
| 222 |
+
if language == "en"
|
| 223 |
+
else "你是一个专业的AI皮肤科助手。"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
if pil_image:
|
| 227 |
+
temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
|
| 228 |
+
pil_image.save(temp_image_path)
|
| 229 |
+
generation_result["temp_image_path"] = str(temp_image_path)
|
| 230 |
+
current_content.append({"type": "image", "image": str(temp_image_path)})
|
| 231 |
+
|
| 232 |
+
current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
|
|
|
|
| 233 |
messages.append({"role": "user", "content": current_content})
|
| 234 |
+
|
|
|
|
| 235 |
for chunk in gpt_model.generate_response_stream(
|
| 236 |
messages=messages,
|
| 237 |
max_new_tokens=2048,
|
| 238 |
+
temperature=0.7,
|
| 239 |
):
|
| 240 |
full_response.append(chunk)
|
| 241 |
result_queue.put(("delta", chunk))
|
| 242 |
+
|
|
|
|
| 243 |
response_text = "".join(full_response)
|
|
|
|
| 244 |
generation_result["full_response"] = full_response
|
| 245 |
+
generation_result["parsed"] = parse_diagnosis_result(response_text)
|
|
|
|
|
|
|
| 246 |
result_queue.put(("generation_done", None))
|
| 247 |
+
except Exception as exc:
|
| 248 |
+
result_queue.put(("error", str(exc)))
|
| 249 |
+
|
|
|
|
| 250 |
async def event_generator():
|
|
|
|
|
|
|
| 251 |
gen_thread = Thread(target=run_generation)
|
| 252 |
gen_thread.start()
|
| 253 |
+
|
| 254 |
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
| 255 |
while True:
|
| 256 |
try:
|
|
|
|
| 257 |
msg_type, data = await loop.run_in_executor(
|
| 258 |
+
None,
|
| 259 |
+
lambda: result_queue.get(timeout=0.1),
|
| 260 |
)
|
|
|
|
| 261 |
if msg_type == "generation_done":
|
|
|
|
| 262 |
break
|
| 263 |
+
if msg_type == "delta":
|
| 264 |
+
yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
|
|
|
|
| 265 |
elif msg_type == "error":
|
| 266 |
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 267 |
gen_thread.join()
|
| 268 |
return
|
|
|
|
| 269 |
except Empty:
|
|
|
|
| 270 |
await asyncio.sleep(0.01)
|
| 271 |
+
|
|
|
|
| 272 |
gen_thread.join()
|
| 273 |
+
|
|
|
|
| 274 |
parsed = generation_result["parsed"]
|
| 275 |
if not parsed:
|
| 276 |
+
yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
|
| 277 |
return
|
| 278 |
+
|
| 279 |
raw_thinking = parsed["thinking"]
|
| 280 |
raw_answer = parsed["answer"]
|
|
|
|
|
|
|
| 281 |
refined_by_deepseek = False
|
| 282 |
description = None
|
| 283 |
thinking = raw_thinking
|
| 284 |
answer = raw_answer
|
| 285 |
+
|
| 286 |
if deepseek_service and deepseek_service.is_loaded:
|
| 287 |
try:
|
|
|
|
| 288 |
refined = await deepseek_service.refine_diagnosis(
|
| 289 |
raw_answer=raw_answer,
|
| 290 |
raw_thinking=raw_thinking,
|
|
|
|
| 295 |
thinking = refined["analysis_process"]
|
| 296 |
answer = refined["diagnosis_result"]
|
| 297 |
refined_by_deepseek = True
|
| 298 |
+
except Exception as exc:
|
| 299 |
+
print(f"DeepSeek refinement failed, using original: {exc}")
|
|
|
|
| 300 |
else:
|
| 301 |
print("DeepSeek service not available, using raw results")
|
| 302 |
+
|
|
|
|
|
|
|
|
|
|
| 303 |
final_payload = {
|
| 304 |
+
"description": description,
|
| 305 |
+
"thinking": thinking,
|
| 306 |
+
"answer": answer,
|
| 307 |
+
"raw": parsed["raw"],
|
| 308 |
+
"refined_by_deepseek": refined_by_deepseek,
|
| 309 |
"success": True,
|
| 310 |
+
"message": "Diagnosis completed" if language == "en" else "诊断完成",
|
| 311 |
}
|
| 312 |
+
yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"
|
| 313 |
+
|
|
|
|
|
|
|
| 314 |
temp_path = generation_result.get("temp_image_path")
|
| 315 |
+
if temp_path:
|
| 316 |
try:
|
| 317 |
+
Path(temp_path).unlink(missing_ok=True)
|
| 318 |
+
except Exception:
|
| 319 |
pass
|
| 320 |
+
|
| 321 |
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 322 |
|
| 323 |
+
|
| 324 |
+
def main() -> None:
|
| 325 |
+
uvicorn.run("app:app", host="0.0.0.0", port=5900, reload=False)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
|
| 329 |
+
main()
|
inference/{chat.py → full_precision/chat.py}
RENAMED
|
@@ -1,48 +1,53 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
import argparse
|
| 3 |
-
import
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
parser.
|
|
|
|
| 9 |
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
| 16 |
-
# 系统提示词
|
| 17 |
-
system_prompt = "You are a professional AI dermatology assistant. Analyze the skin condition carefully."
|
| 18 |
-
|
| 19 |
-
# 构造第一条包含图片的消息
|
| 20 |
-
if not os.path.exists(args.image):
|
| 21 |
print(f"Error: Image {args.image} not found.")
|
| 22 |
return
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
]
|
| 31 |
-
}
|
| 32 |
-
]
|
| 33 |
|
| 34 |
print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
|
| 35 |
print(f"Image loaded: {args.image}")
|
| 36 |
-
|
| 37 |
-
# 获取第一轮诊断
|
| 38 |
print("\nModel is thinking...", end="", flush=True)
|
| 39 |
-
response =
|
| 40 |
print(f"\rAssistant: {response}\n")
|
| 41 |
-
|
| 42 |
-
# 将助手的回复加入历史
|
| 43 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 44 |
|
| 45 |
-
# 进入多轮对话循环
|
| 46 |
while True:
|
| 47 |
try:
|
| 48 |
user_input = input("User: ")
|
|
@@ -51,18 +56,16 @@ def main():
|
|
| 51 |
if not user_input.strip():
|
| 52 |
continue
|
| 53 |
|
| 54 |
-
# 加入用户的新问题
|
| 55 |
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
| 56 |
|
| 57 |
print("Model is thinking...", end="", flush=True)
|
| 58 |
-
response =
|
| 59 |
print(f"\rAssistant: {response}\n")
|
| 60 |
|
| 61 |
-
# 加入助手的新回复
|
| 62 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 63 |
-
|
| 64 |
except KeyboardInterrupt:
|
| 65 |
break
|
| 66 |
|
|
|
|
| 67 |
if __name__ == "__main__":
|
| 68 |
-
main()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from .model_utils import (
|
| 8 |
+
DEFAULT_MODEL_PATH,
|
| 9 |
+
SkinGPTModel,
|
| 10 |
+
build_single_turn_messages,
|
| 11 |
+
resolve_model_path,
|
| 12 |
+
)
|
| 13 |
+
except ImportError:
|
| 14 |
+
from model_utils import (
|
| 15 |
+
DEFAULT_MODEL_PATH,
|
| 16 |
+
SkinGPTModel,
|
| 17 |
+
build_single_turn_messages,
|
| 18 |
+
resolve_model_path,
|
| 19 |
+
)
|
| 20 |
|
| 21 |
+
|
| 22 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 23 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 full-precision multi-turn chat")
|
| 24 |
+
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
|
| 25 |
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 26 |
+
return parser
|
| 27 |
+
|
| 28 |
|
| 29 |
+
def main() -> None:
|
| 30 |
+
args = build_parser().parse_args()
|
| 31 |
|
| 32 |
+
if not Path(args.image).exists():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
print(f"Error: Image {args.image} not found.")
|
| 34 |
return
|
| 35 |
|
| 36 |
+
model = SkinGPTModel(resolve_model_path(args.model_path))
|
| 37 |
+
history = build_single_turn_messages(
|
| 38 |
+
args.image,
|
| 39 |
+
"Please analyze this image.",
|
| 40 |
+
system_prompt="You are a professional AI dermatology assistant. Analyze the skin condition carefully.",
|
| 41 |
+
)
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
|
| 44 |
print(f"Image loaded: {args.image}")
|
| 45 |
+
|
|
|
|
| 46 |
print("\nModel is thinking...", end="", flush=True)
|
| 47 |
+
response = model.generate_response(history)
|
| 48 |
print(f"\rAssistant: {response}\n")
|
|
|
|
|
|
|
| 49 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 50 |
|
|
|
|
| 51 |
while True:
|
| 52 |
try:
|
| 53 |
user_input = input("User: ")
|
|
|
|
| 56 |
if not user_input.strip():
|
| 57 |
continue
|
| 58 |
|
|
|
|
| 59 |
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
| 60 |
|
| 61 |
print("Model is thinking...", end="", flush=True)
|
| 62 |
+
response = model.generate_response(history)
|
| 63 |
print(f"\rAssistant: {response}\n")
|
| 64 |
|
|
|
|
| 65 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
|
|
|
| 66 |
except KeyboardInterrupt:
|
| 67 |
break
|
| 68 |
|
| 69 |
+
|
| 70 |
if __name__ == "__main__":
|
| 71 |
+
main()
|
inference/{deepseek_service.py → full_precision/deepseek_service.py}
RENAMED
|
@@ -1,75 +1,51 @@
|
|
| 1 |
-
|
| 2 |
-
DeepSeek API Service
|
| 3 |
-
Used to optimize and organize SkinGPT model output results
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
from typing import Optional
|
|
|
|
| 9 |
from openai import AsyncOpenAI
|
| 10 |
|
| 11 |
|
| 12 |
class DeepSeekService:
|
| 13 |
-
"""DeepSeek
|
| 14 |
-
|
| 15 |
def __init__(self, api_key: Optional[str] = None):
|
| 16 |
-
"""
|
| 17 |
-
Initialize DeepSeek service
|
| 18 |
-
|
| 19 |
-
Parameters:
|
| 20 |
-
api_key: DeepSeek API key, reads from environment variable if not provided
|
| 21 |
-
"""
|
| 22 |
self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
| 23 |
self.base_url = "https://api.deepseek.com"
|
| 24 |
-
self.model = "deepseek-chat"
|
| 25 |
-
|
| 26 |
self.client = None
|
| 27 |
self.is_loaded = False
|
| 28 |
-
|
| 29 |
-
print(
|
| 30 |
print(f"API Base URL: {self.base_url}")
|
| 31 |
-
|
| 32 |
async def load(self):
|
| 33 |
-
"""Initialize DeepSeek API client"""
|
| 34 |
try:
|
| 35 |
if not self.api_key:
|
| 36 |
print("DeepSeek API key not provided")
|
| 37 |
self.is_loaded = False
|
| 38 |
return
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
self.client = AsyncOpenAI(
|
| 42 |
-
api_key=self.api_key,
|
| 43 |
-
base_url=self.base_url
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
self.is_loaded = True
|
| 47 |
print("DeepSeek API service is ready!")
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
print(f"DeepSeek API service initialization failed: {e}")
|
| 51 |
self.is_loaded = False
|
| 52 |
-
|
| 53 |
async def refine_diagnosis(
|
| 54 |
-
self,
|
| 55 |
raw_answer: str,
|
| 56 |
raw_thinking: Optional[str] = None,
|
| 57 |
-
language: str = "zh"
|
| 58 |
) -> dict:
|
| 59 |
-
"""
|
| 60 |
-
Use DeepSeek API to optimize and organize diagnosis results
|
| 61 |
-
|
| 62 |
-
Parameters:
|
| 63 |
-
raw_answer: Original diagnosis result
|
| 64 |
-
raw_thinking: AI thinking process
|
| 65 |
-
language: Language option
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
Dictionary containing "description", "analysis_process" and "diagnosis_result"
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
if not self.is_loaded or self.client is None:
|
| 72 |
-
error_msg =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
print("DeepSeek API not initialized, returning original result")
|
| 74 |
return {
|
| 75 |
"success": False,
|
|
@@ -77,74 +53,67 @@ class DeepSeekService:
|
|
| 77 |
"analysis_process": raw_thinking or error_msg,
|
| 78 |
"diagnosis_result": raw_answer,
|
| 79 |
"original_diagnosis": raw_answer,
|
| 80 |
-
"error": "DeepSeek API not initialized"
|
| 81 |
}
|
| 82 |
-
|
| 83 |
try:
|
| 84 |
-
# Build prompt
|
| 85 |
prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
| 94 |
response = await self.client.chat.completions.create(
|
| 95 |
model=self.model,
|
| 96 |
messages=[
|
| 97 |
{"role": "system", "content": system_content},
|
| 98 |
-
{"role": "user", "content": prompt}
|
| 99 |
],
|
| 100 |
temperature=0.1,
|
| 101 |
max_tokens=2048,
|
| 102 |
top_p=0.8,
|
| 103 |
)
|
| 104 |
-
|
| 105 |
-
# Extract generated text
|
| 106 |
generated_text = response.choices[0].message.content
|
| 107 |
-
|
| 108 |
-
# Parse output
|
| 109 |
parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
|
| 110 |
-
|
| 111 |
return {
|
| 112 |
"success": True,
|
| 113 |
"description": parsed["description"],
|
| 114 |
"analysis_process": parsed["analysis_process"],
|
| 115 |
"diagnosis_result": parsed["diagnosis_result"],
|
| 116 |
"original_diagnosis": raw_answer,
|
| 117 |
-
"raw_refined": generated_text
|
| 118 |
}
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
| 123 |
return {
|
| 124 |
"success": False,
|
| 125 |
"description": "",
|
| 126 |
"analysis_process": raw_thinking or error_msg,
|
| 127 |
"diagnosis_result": raw_answer,
|
| 128 |
"original_diagnosis": raw_answer,
|
| 129 |
-
"error": str(
|
| 130 |
}
|
| 131 |
-
|
| 132 |
-
def _build_refine_prompt(
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
language: Language option, "zh" for Chinese, "en" for English
|
| 140 |
-
|
| 141 |
-
Returns:
|
| 142 |
-
Built prompt
|
| 143 |
-
"""
|
| 144 |
if language == "en":
|
| 145 |
-
|
| 146 |
-
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 147 |
-
prompt = f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
|
| 148 |
|
| 149 |
【Requirements】
|
| 150 |
- Preserve the original tone and expression style
|
|
@@ -171,21 +140,9 @@ class DeepSeekService:
|
|
| 171 |
|
| 172 |
## Diagnosis Result
|
| 173 |
(The organized diagnosis result from Text 2)
|
| 174 |
-
|
| 175 |
-
【Example】:
|
| 176 |
-
## Description
|
| 177 |
-
The image shows red inflamed patches on the skin with pustules and darker colored spots. The lesions appear as papules and pustules distributed across the affected area, with some showing signs of inflammation and possible post-inflammatory hyperpigmentation.
|
| 178 |
-
|
| 179 |
-
## Analysis Process
|
| 180 |
-
These findings are consistent with acne vulgaris, commonly seen during adolescence. The user's age aligns with typical onset for this condition. Treatment recommendations: over-the-counter medications such as benzoyl peroxide or topical antibiotics, avoiding picking at the skin, and consulting a dermatologist if severe. The goal is to control inflammation and prevent scarring.
|
| 181 |
-
|
| 182 |
-
## Diagnosis Result
|
| 183 |
-
Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition, especially during adolescence, when hormonal changes cause overactive sebaceous glands, which can easily clog pores and form acne. Pathological care recommendations: 1. Keep face clean, wash face 2-3 times daily, use gentle cleansing products. 2. Avoid squeezing acne with hands to prevent worsening inflammation or leaving scars. 3. Avoid using irritating cosmetics and skincare products. 4. Can use topical medications containing salicylic acid, benzoyl peroxide, etc. 5. If necessary, can use oral antibiotics or other treatment methods under doctor's guidance. Precautions: 1. Avoid rubbing or damaging the affected area to prevent infection. 2. Eat less oily and spicy foods, eat more vegetables and fruits. 3. Maintain good rest habits, avoid staying up late. 4. If acne symptoms persist without improvement or show signs of worsening, seek medical attention promptly.
|
| 184 |
"""
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
| 188 |
-
prompt = f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
|
| 189 |
|
| 190 |
【要求】
|
| 191 |
- 保留原文的语气和表达方式
|
|
@@ -198,7 +155,6 @@ Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition,
|
|
| 198 |
- 禁止推断或添加新的医学信息,禁止输出任何元评论
|
| 199 |
- 可以调整不合理的语句或去除冗余内容以提高清晰度
|
| 200 |
|
| 201 |
-
|
| 202 |
【文本1】
|
| 203 |
{thinking_text}
|
| 204 |
|
|
@@ -214,171 +170,102 @@ Possible diagnosis: Acne (pimples) Explanation: Acne is a common skin condition,
|
|
| 214 |
|
| 215 |
## 诊断结果
|
| 216 |
(整理后的诊断结果)
|
| 217 |
-
|
| 218 |
-
【样例】:
|
| 219 |
-
## 图像描述
|
| 220 |
-
图片显示皮肤上有红色发炎的斑块,伴有脓疱和颜色较深的斑点。病变表现为分布在受影响区域的丘疹和脓疱,部分显示出炎症迹象和可能的炎症后色素沉着。
|
| 221 |
-
|
| 222 |
-
## 分析过程
|
| 223 |
-
这些表现符合寻常痤疮的特征,青春期常见。用户的年龄与该病症的典型发病年龄相符。治疗建议:使用非处方药物如过氧化苯甲酰或外用抗生素,避免抠抓皮肤,病情严重时咨询皮肤科医生。目标是控制炎症并防止疤痕形成。
|
| 224 |
-
|
| 225 |
-
## 诊断结果
|
| 226 |
-
可能的诊断:痤疮(青春痘) 解释:痤疮是一种常见的皮肤病,特别是在青少年期间,由于激素水平的变化导致皮脂腺过度活跃,容易堵塞毛孔,形成痤疮。 病理护理建议:1.保持面部清洁,每天洗脸2-3次,使用温和的洁面产品。 2.避免用手挤压痤疮,以免加重炎症或留下疤痕。 3.避免使用刺激性的化妆品和护肤品。 4.可以使用含有水杨酸、苯氧醇等成分的外用药物治疗。 5.如有需要,可以在医生指导下使用抗生素口服药或其他治疗方法。 注意事项:1. 避免摩擦或损伤患处,以免引起感染。 2. 饮食上应少吃油腻、辛辣食物,多吃蔬菜水果。 3. 保持良好的作息习惯,避免熬夜。 4. 如果痤疮症状持续不见好转或有恶化的趋势,应及时就医。
|
| 227 |
"""
|
| 228 |
-
|
| 229 |
-
return prompt
|
| 230 |
-
|
| 231 |
def _parse_refined_output(
|
| 232 |
-
self,
|
| 233 |
-
generated_text: str,
|
| 234 |
raw_answer: str,
|
| 235 |
raw_thinking: Optional[str] = None,
|
| 236 |
-
language: str = "zh"
|
| 237 |
) -> dict:
|
| 238 |
-
"""
|
| 239 |
-
Parse DeepSeek generated output
|
| 240 |
-
|
| 241 |
-
Parameters:
|
| 242 |
-
generated_text: DeepSeek generated text
|
| 243 |
-
raw_answer: Original diagnosis (as fallback)
|
| 244 |
-
raw_thinking: Original thinking process (as fallback)
|
| 245 |
-
language: Language option
|
| 246 |
-
|
| 247 |
-
Returns:
|
| 248 |
-
Dictionary containing description, analysis_process and diagnosis_result
|
| 249 |
-
"""
|
| 250 |
description = ""
|
| 251 |
analysis_process = None
|
| 252 |
diagnosis_result = None
|
| 253 |
-
|
| 254 |
if language == "en":
|
| 255 |
-
# English patterns
|
| 256 |
desc_match = re.search(
|
| 257 |
-
r
|
| 258 |
generated_text,
|
| 259 |
-
re.IGNORECASE
|
| 260 |
)
|
| 261 |
analysis_match = re.search(
|
| 262 |
-
r
|
| 263 |
generated_text,
|
| 264 |
-
re.IGNORECASE
|
| 265 |
)
|
| 266 |
result_match = re.search(
|
| 267 |
-
r
|
| 268 |
generated_text,
|
| 269 |
-
re.IGNORECASE
|
| 270 |
)
|
| 271 |
-
|
| 272 |
desc_header = "## Description"
|
| 273 |
analysis_header = "## Analysis Process"
|
| 274 |
result_header = "## Diagnosis Result"
|
| 275 |
else:
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
generated_text
|
| 280 |
-
)
|
| 281 |
-
analysis_match = re.search(
|
| 282 |
-
r'##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)',
|
| 283 |
-
generated_text
|
| 284 |
-
)
|
| 285 |
-
result_match = re.search(
|
| 286 |
-
r'##\s*诊断结果\s*\n([\s\S]*?)$',
|
| 287 |
-
generated_text
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
desc_header = "## 图像描述"
|
| 291 |
analysis_header = "## 分析过程"
|
| 292 |
result_header = "## 诊断结果"
|
| 293 |
-
|
| 294 |
-
# Extract description
|
| 295 |
if desc_match:
|
| 296 |
description = desc_match.group(1).strip()
|
| 297 |
-
print(f"Successfully parsed description")
|
| 298 |
else:
|
| 299 |
-
print(f"Description parsing failed")
|
| 300 |
description = ""
|
| 301 |
-
|
| 302 |
-
# Extract analysis process
|
| 303 |
if analysis_match:
|
| 304 |
analysis_process = analysis_match.group(1).strip()
|
| 305 |
-
print(f"Successfully parsed analysis process")
|
| 306 |
else:
|
| 307 |
-
print(f"Analysis process parsing failed, trying other methods")
|
| 308 |
-
# Try to extract from generated text
|
| 309 |
result_pos = generated_text.find(result_header)
|
| 310 |
if result_pos > 0:
|
| 311 |
-
# Get content before diagnosis result
|
| 312 |
analysis_process = generated_text[:result_pos].strip()
|
| 313 |
-
# Remove possible headers
|
| 314 |
for header in [desc_header, analysis_header]:
|
| 315 |
-
|
| 316 |
-
analysis_process = re.sub(f'{header_escaped}\\s*\\n?', '', analysis_process).strip()
|
| 317 |
else:
|
| 318 |
-
|
| 319 |
-
mid_point = len(generated_text) // 2
|
| 320 |
-
analysis_process = generated_text[:mid_point].strip()
|
| 321 |
-
|
| 322 |
-
# If still empty, use original content (final fallback)
|
| 323 |
if not analysis_process and raw_thinking:
|
| 324 |
-
print(f"Using original raw_thinking as fallback")
|
| 325 |
analysis_process = raw_thinking
|
| 326 |
-
|
| 327 |
-
# Extract diagnosis result
|
| 328 |
if result_match:
|
| 329 |
diagnosis_result = result_match.group(1).strip()
|
| 330 |
-
print(f"Successfully parsed diagnosis result")
|
| 331 |
else:
|
| 332 |
-
print(f"Diagnosis result parsing failed, trying other methods")
|
| 333 |
-
# Try to extract from generated text
|
| 334 |
result_pos = generated_text.find(result_header)
|
| 335 |
if result_pos > 0:
|
| 336 |
diagnosis_result = generated_text[result_pos:].strip()
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
| 340 |
else:
|
| 341 |
-
|
| 342 |
-
mid_point = len(generated_text) // 2
|
| 343 |
-
diagnosis_result = generated_text[mid_point:].strip()
|
| 344 |
-
|
| 345 |
-
# If still empty, use original content (final fallback)
|
| 346 |
if not diagnosis_result:
|
| 347 |
-
print(f"Using original raw_answer as fallback")
|
| 348 |
diagnosis_result = raw_answer
|
| 349 |
-
|
| 350 |
return {
|
| 351 |
"description": description,
|
| 352 |
"analysis_process": analysis_process,
|
| 353 |
-
"diagnosis_result": diagnosis_result
|
| 354 |
}
|
| 355 |
|
| 356 |
|
| 357 |
-
# Global DeepSeek service instance (lazy loading)
|
| 358 |
_deepseek_service: Optional[DeepSeekService] = None
|
| 359 |
|
| 360 |
|
| 361 |
async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
|
| 362 |
-
"""
|
| 363 |
-
Get DeepSeek service instance (singleton pattern)
|
| 364 |
-
|
| 365 |
-
Parameters:
|
| 366 |
-
api_key: Optional API key to use
|
| 367 |
-
|
| 368 |
-
Returns:
|
| 369 |
-
DeepSeekService instance, or None if API initialization fails
|
| 370 |
-
"""
|
| 371 |
global _deepseek_service
|
| 372 |
-
|
| 373 |
if _deepseek_service is None:
|
| 374 |
try:
|
| 375 |
_deepseek_service = DeepSeekService(api_key=api_key)
|
| 376 |
await _deepseek_service.load()
|
| 377 |
if not _deepseek_service.is_loaded:
|
| 378 |
print("DeepSeek API service initialization failed, will use fallback mode")
|
| 379 |
-
return _deepseek_service
|
| 380 |
-
except Exception as
|
| 381 |
-
print(f"DeepSeek service initialization failed: {
|
| 382 |
return None
|
| 383 |
-
|
| 384 |
return _deepseek_service
|
|
|
|
| 1 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
from typing import Optional
|
| 6 |
+
|
| 7 |
from openai import AsyncOpenAI
|
| 8 |
|
| 9 |
|
| 10 |
class DeepSeekService:
|
| 11 |
+
"""OpenAI-compatible DeepSeek refinement service."""
|
| 12 |
+
|
| 13 |
def __init__(self, api_key: Optional[str] = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
self.api_key = api_key or os.environ.get("DEEPSEEK_API_KEY")
|
| 15 |
self.base_url = "https://api.deepseek.com"
|
| 16 |
+
self.model = "deepseek-chat"
|
|
|
|
| 17 |
self.client = None
|
| 18 |
self.is_loaded = False
|
| 19 |
+
|
| 20 |
+
print("DeepSeek API service initializing...")
|
| 21 |
print(f"API Base URL: {self.base_url}")
|
| 22 |
+
|
| 23 |
async def load(self):
|
|
|
|
| 24 |
try:
|
| 25 |
if not self.api_key:
|
| 26 |
print("DeepSeek API key not provided")
|
| 27 |
self.is_loaded = False
|
| 28 |
return
|
| 29 |
+
|
| 30 |
+
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.is_loaded = True
|
| 32 |
print("DeepSeek API service is ready!")
|
| 33 |
+
except Exception as exc:
|
| 34 |
+
print(f"DeepSeek API service initialization failed: {exc}")
|
|
|
|
| 35 |
self.is_loaded = False
|
| 36 |
+
|
| 37 |
async def refine_diagnosis(
|
| 38 |
+
self,
|
| 39 |
raw_answer: str,
|
| 40 |
raw_thinking: Optional[str] = None,
|
| 41 |
+
language: str = "zh",
|
| 42 |
) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if not self.is_loaded or self.client is None:
|
| 44 |
+
error_msg = (
|
| 45 |
+
"API not initialized, cannot generate analysis"
|
| 46 |
+
if language == "en"
|
| 47 |
+
else "API未初始化,无法生成分析过程"
|
| 48 |
+
)
|
| 49 |
print("DeepSeek API not initialized, returning original result")
|
| 50 |
return {
|
| 51 |
"success": False,
|
|
|
|
| 53 |
"analysis_process": raw_thinking or error_msg,
|
| 54 |
"diagnosis_result": raw_answer,
|
| 55 |
"original_diagnosis": raw_answer,
|
| 56 |
+
"error": "DeepSeek API not initialized",
|
| 57 |
}
|
| 58 |
+
|
| 59 |
try:
|
|
|
|
| 60 |
prompt = self._build_refine_prompt(raw_answer, raw_thinking, language)
|
| 61 |
+
system_content = (
|
| 62 |
+
"You are a professional medical text editor. Your task is to polish and organize "
|
| 63 |
+
"medical diagnostic text to make it flow smoothly while preserving the original "
|
| 64 |
+
"meaning. Output ONLY the formatted result. Do NOT add any explanations, comments, "
|
| 65 |
+
"or thoughts. Just follow the format exactly."
|
| 66 |
+
if language == "en"
|
| 67 |
+
else "你是医学文本整理专家,按照用户要求将用户输入的文本整理成用户想要的格式,不要改写或总结。"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
response = await self.client.chat.completions.create(
|
| 71 |
model=self.model,
|
| 72 |
messages=[
|
| 73 |
{"role": "system", "content": system_content},
|
| 74 |
+
{"role": "user", "content": prompt},
|
| 75 |
],
|
| 76 |
temperature=0.1,
|
| 77 |
max_tokens=2048,
|
| 78 |
top_p=0.8,
|
| 79 |
)
|
| 80 |
+
|
|
|
|
| 81 |
generated_text = response.choices[0].message.content
|
|
|
|
|
|
|
| 82 |
parsed = self._parse_refined_output(generated_text, raw_answer, raw_thinking, language)
|
| 83 |
+
|
| 84 |
return {
|
| 85 |
"success": True,
|
| 86 |
"description": parsed["description"],
|
| 87 |
"analysis_process": parsed["analysis_process"],
|
| 88 |
"diagnosis_result": parsed["diagnosis_result"],
|
| 89 |
"original_diagnosis": raw_answer,
|
| 90 |
+
"raw_refined": generated_text,
|
| 91 |
}
|
| 92 |
+
except Exception as exc:
|
| 93 |
+
print(f"DeepSeek API call failed: {exc}")
|
| 94 |
+
error_msg = (
|
| 95 |
+
"API call failed, cannot generate analysis"
|
| 96 |
+
if language == "en"
|
| 97 |
+
else "API调用失败,无法生成分析过程"
|
| 98 |
+
)
|
| 99 |
return {
|
| 100 |
"success": False,
|
| 101 |
"description": "",
|
| 102 |
"analysis_process": raw_thinking or error_msg,
|
| 103 |
"diagnosis_result": raw_answer,
|
| 104 |
"original_diagnosis": raw_answer,
|
| 105 |
+
"error": str(exc),
|
| 106 |
}
|
| 107 |
+
|
| 108 |
+
def _build_refine_prompt(
|
| 109 |
+
self,
|
| 110 |
+
raw_answer: str,
|
| 111 |
+
raw_thinking: Optional[str] = None,
|
| 112 |
+
language: str = "zh",
|
| 113 |
+
) -> str:
|
| 114 |
+
thinking_text = raw_thinking if raw_thinking else "No analysis process available."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
if language == "en":
|
| 116 |
+
return f"""You are a text organization expert. There are two texts that need to be organized. Text 1 is the thinking process of the SkinGPT model, and Text 2 is the diagnosis result given by SkinGPT.
|
|
|
|
|
|
|
| 117 |
|
| 118 |
【Requirements】
|
| 119 |
- Preserve the original tone and expression style
|
|
|
|
| 140 |
|
| 141 |
## Diagnosis Result
|
| 142 |
(The organized diagnosis result from Text 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
"""
|
| 144 |
+
|
| 145 |
+
return f"""你是一个文本整理专家。有两段文本需要整理,文本1是SkinGPT模型的思考过程的文本,文本2是SkinGPT给出的诊断结果的文本。
|
|
|
|
|
|
|
| 146 |
|
| 147 |
【要求】
|
| 148 |
- 保留原文的语气和表达方式
|
|
|
|
| 155 |
- 禁止推断或添加新的医学信息,禁止输出任何元评论
|
| 156 |
- 可以调整不合理的语句或去除冗余内容以提高清晰度
|
| 157 |
|
|
|
|
| 158 |
【文本1】
|
| 159 |
{thinking_text}
|
| 160 |
|
|
|
|
| 170 |
|
| 171 |
## 诊断结果
|
| 172 |
(整理后的诊断结果)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
"""
|
| 174 |
+
|
|
|
|
|
|
|
| 175 |
def _parse_refined_output(
|
| 176 |
+
self,
|
| 177 |
+
generated_text: str,
|
| 178 |
raw_answer: str,
|
| 179 |
raw_thinking: Optional[str] = None,
|
| 180 |
+
language: str = "zh",
|
| 181 |
) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
description = ""
|
| 183 |
analysis_process = None
|
| 184 |
diagnosis_result = None
|
| 185 |
+
|
| 186 |
if language == "en":
|
|
|
|
| 187 |
desc_match = re.search(
|
| 188 |
+
r"##\s*Description\s*\n([\s\S]*?)(?=##\s*Analysis\s*Process|$)",
|
| 189 |
generated_text,
|
| 190 |
+
re.IGNORECASE,
|
| 191 |
)
|
| 192 |
analysis_match = re.search(
|
| 193 |
+
r"##\s*Analysis\s*Process\s*\n([\s\S]*?)(?=##\s*Diagnosis\s*Result|$)",
|
| 194 |
generated_text,
|
| 195 |
+
re.IGNORECASE,
|
| 196 |
)
|
| 197 |
result_match = re.search(
|
| 198 |
+
r"##\s*Diagnosis\s*Result\s*\n([\s\S]*?)$",
|
| 199 |
generated_text,
|
| 200 |
+
re.IGNORECASE,
|
| 201 |
)
|
|
|
|
| 202 |
desc_header = "## Description"
|
| 203 |
analysis_header = "## Analysis Process"
|
| 204 |
result_header = "## Diagnosis Result"
|
| 205 |
else:
|
| 206 |
+
desc_match = re.search(r"##\s*图像描述\s*\n([\s\S]*?)(?=##\s*分析过程|$)", generated_text)
|
| 207 |
+
analysis_match = re.search(r"##\s*分析过程\s*\n([\s\S]*?)(?=##\s*诊断结果|$)", generated_text)
|
| 208 |
+
result_match = re.search(r"##\s*诊断结果\s*\n([\s\S]*?)$", generated_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
desc_header = "## 图像描述"
|
| 210 |
analysis_header = "## 分析过程"
|
| 211 |
result_header = "## 诊断结果"
|
| 212 |
+
|
|
|
|
| 213 |
if desc_match:
|
| 214 |
description = desc_match.group(1).strip()
|
|
|
|
| 215 |
else:
|
|
|
|
| 216 |
description = ""
|
| 217 |
+
|
|
|
|
| 218 |
if analysis_match:
|
| 219 |
analysis_process = analysis_match.group(1).strip()
|
|
|
|
| 220 |
else:
|
|
|
|
|
|
|
| 221 |
result_pos = generated_text.find(result_header)
|
| 222 |
if result_pos > 0:
|
|
|
|
| 223 |
analysis_process = generated_text[:result_pos].strip()
|
|
|
|
| 224 |
for header in [desc_header, analysis_header]:
|
| 225 |
+
analysis_process = re.sub(f"{re.escape(header)}\\s*\\n?", "", analysis_process).strip()
|
|
|
|
| 226 |
else:
|
| 227 |
+
analysis_process = generated_text[: len(generated_text) // 2].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
if not analysis_process and raw_thinking:
|
|
|
|
| 229 |
analysis_process = raw_thinking
|
| 230 |
+
|
|
|
|
| 231 |
if result_match:
|
| 232 |
diagnosis_result = result_match.group(1).strip()
|
|
|
|
| 233 |
else:
|
|
|
|
|
|
|
| 234 |
result_pos = generated_text.find(result_header)
|
| 235 |
if result_pos > 0:
|
| 236 |
diagnosis_result = generated_text[result_pos:].strip()
|
| 237 |
+
diagnosis_result = re.sub(
|
| 238 |
+
f"^{re.escape(result_header)}\\s*\\n?",
|
| 239 |
+
"",
|
| 240 |
+
diagnosis_result,
|
| 241 |
+
).strip()
|
| 242 |
else:
|
| 243 |
+
diagnosis_result = generated_text[len(generated_text) // 2 :].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
if not diagnosis_result:
|
|
|
|
| 245 |
diagnosis_result = raw_answer
|
| 246 |
+
|
| 247 |
return {
|
| 248 |
"description": description,
|
| 249 |
"analysis_process": analysis_process,
|
| 250 |
+
"diagnosis_result": diagnosis_result,
|
| 251 |
}
|
| 252 |
|
| 253 |
|
|
|
|
| 254 |
_deepseek_service: Optional[DeepSeekService] = None
|
| 255 |
|
| 256 |
|
| 257 |
async def get_deepseek_service(api_key: Optional[str] = None) -> Optional[DeepSeekService]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
global _deepseek_service
|
| 259 |
+
|
| 260 |
if _deepseek_service is None:
|
| 261 |
try:
|
| 262 |
_deepseek_service = DeepSeekService(api_key=api_key)
|
| 263 |
await _deepseek_service.load()
|
| 264 |
if not _deepseek_service.is_loaded:
|
| 265 |
print("DeepSeek API service initialization failed, will use fallback mode")
|
| 266 |
+
return _deepseek_service
|
| 267 |
+
except Exception as exc:
|
| 268 |
+
print(f"DeepSeek service initialization failed: {exc}")
|
| 269 |
return None
|
| 270 |
+
|
| 271 |
return _deepseek_service
|
inference/full_precision/demo.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from .model_utils import (
|
| 7 |
+
DEFAULT_MODEL_PATH,
|
| 8 |
+
SkinGPTModel,
|
| 9 |
+
build_single_turn_messages,
|
| 10 |
+
resolve_model_path,
|
| 11 |
+
)
|
| 12 |
+
except ImportError:
|
| 13 |
+
from model_utils import (
|
| 14 |
+
DEFAULT_MODEL_PATH,
|
| 15 |
+
SkinGPTModel,
|
| 16 |
+
build_single_turn_messages,
|
| 17 |
+
resolve_model_path,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
IMAGE_PATH = "test_image.jpg"
|
| 21 |
+
PROMPT = "Please analyze this skin image and provide a diagnosis."
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main() -> None:
|
| 25 |
+
if not Path(IMAGE_PATH).exists():
|
| 26 |
+
print(f"Warning: Image not found at '{IMAGE_PATH}'. Please edit IMAGE_PATH in demo.py")
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
model = SkinGPTModel(resolve_model_path(DEFAULT_MODEL_PATH))
|
| 30 |
+
messages = build_single_turn_messages(IMAGE_PATH, PROMPT)
|
| 31 |
+
|
| 32 |
+
print("Processing...")
|
| 33 |
+
output_text = model.generate_response(messages)
|
| 34 |
+
|
| 35 |
+
print("\n=== Diagnosis Result ===")
|
| 36 |
+
print(output_text)
|
| 37 |
+
print("========================")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
inference/full_precision/infer.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from .model_utils import (
|
| 8 |
+
DEFAULT_MODEL_PATH,
|
| 9 |
+
SkinGPTModel,
|
| 10 |
+
build_single_turn_messages,
|
| 11 |
+
resolve_model_path,
|
| 12 |
+
)
|
| 13 |
+
except ImportError:
|
| 14 |
+
from model_utils import (
|
| 15 |
+
DEFAULT_MODEL_PATH,
|
| 16 |
+
SkinGPTModel,
|
| 17 |
+
build_single_turn_messages,
|
| 18 |
+
resolve_model_path,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 23 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 full-precision single inference")
|
| 24 |
+
parser.add_argument("--image", type=str, required=True, help="Path to the image")
|
| 25 |
+
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--prompt",
|
| 28 |
+
type=str,
|
| 29 |
+
default="Please analyze this skin image and provide a diagnosis.",
|
| 30 |
+
)
|
| 31 |
+
return parser
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main() -> None:
|
| 35 |
+
args = build_parser().parse_args()
|
| 36 |
+
|
| 37 |
+
if not Path(args.image).exists():
|
| 38 |
+
print(f"Error: Image not found at {args.image}")
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
model = SkinGPTModel(resolve_model_path(args.model_path))
|
| 42 |
+
messages = build_single_turn_messages(args.image, args.prompt)
|
| 43 |
+
|
| 44 |
+
print(f"\nAnalyzing {args.image}...")
|
| 45 |
+
response = model.generate_response(messages)
|
| 46 |
+
|
| 47 |
+
print("-" * 40)
|
| 48 |
+
print("Result:")
|
| 49 |
+
print(response)
|
| 50 |
+
print("-" * 40)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
main()
|
inference/{model_utils.py → full_precision/model_utils.py}
RENAMED
|
@@ -1,51 +1,96 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
|
| 4 |
from qwen_vl_utils import process_vision_info
|
| 5 |
-
from
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class SkinGPTModel:
|
| 10 |
-
def __init__(self, model_path, device=None):
|
| 11 |
-
|
|
|
|
| 12 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
-
print(f"Loading model from {
|
| 14 |
-
|
| 15 |
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 16 |
-
|
| 17 |
torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
|
| 18 |
attn_implementation="flash_attention_2" if self.device == "cuda" else None,
|
| 19 |
device_map="auto" if self.device != "mps" else None,
|
| 20 |
-
trust_remote_code=True
|
| 21 |
)
|
| 22 |
-
|
| 23 |
if self.device == "mps":
|
| 24 |
self.model = self.model.to(self.device)
|
| 25 |
|
| 26 |
self.processor = AutoProcessor.from_pretrained(
|
| 27 |
-
|
| 28 |
-
trust_remote_code=True,
|
| 29 |
-
min_pixels=256*28*28,
|
| 30 |
-
max_pixels=1280*28*28
|
| 31 |
)
|
| 32 |
print("Model loaded successfully.")
|
| 33 |
|
| 34 |
-
def generate_response(
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
image_inputs, video_inputs = process_vision_info(messages)
|
| 48 |
-
|
| 49 |
inputs = self.processor(
|
| 50 |
text=[text],
|
| 51 |
images=image_inputs,
|
|
@@ -62,30 +107,35 @@ class SkinGPTModel:
|
|
| 62 |
repetition_penalty=repetition_penalty,
|
| 63 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 64 |
top_p=0.9,
|
| 65 |
-
do_sample=True
|
| 66 |
)
|
| 67 |
|
| 68 |
-
# 解码输出 (去除输入的token)
|
| 69 |
generated_ids_trimmed = [
|
| 70 |
-
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 71 |
]
|
| 72 |
output_text = self.processor.batch_decode(
|
| 73 |
-
generated_ids_trimmed,
|
|
|
|
|
|
|
| 74 |
)
|
| 75 |
-
|
| 76 |
return output_text[0]
|
| 77 |
-
|
| 78 |
-
def generate_response_stream(
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
image_inputs, video_inputs = process_vision_info(messages)
|
| 88 |
-
|
| 89 |
inputs = self.processor(
|
| 90 |
text=[text],
|
| 91 |
images=image_inputs,
|
|
@@ -93,15 +143,13 @@ class SkinGPTModel:
|
|
| 93 |
padding=True,
|
| 94 |
return_tensors="pt",
|
| 95 |
).to(self.model.device)
|
| 96 |
-
|
| 97 |
-
# 创建 TextIteratorStreamer 用于流式输出
|
| 98 |
streamer = TextIteratorStreamer(
|
| 99 |
self.processor.tokenizer,
|
| 100 |
skip_prompt=True,
|
| 101 |
-
skip_special_tokens=True
|
| 102 |
)
|
| 103 |
-
|
| 104 |
-
# 准备生成参数
|
| 105 |
generation_kwargs = {
|
| 106 |
**inputs,
|
| 107 |
"max_new_tokens": max_new_tokens,
|
|
@@ -112,13 +160,11 @@ class SkinGPTModel:
|
|
| 112 |
"do_sample": True,
|
| 113 |
"streamer": streamer,
|
| 114 |
}
|
| 115 |
-
|
| 116 |
-
# 在单独的线程中运行生成
|
| 117 |
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
| 118 |
thread.start()
|
| 119 |
-
|
| 120 |
-
# 逐个yield生成的文本
|
| 121 |
for text_chunk in streamer:
|
| 122 |
yield text_chunk
|
| 123 |
-
|
| 124 |
-
thread.join()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from threading import Thread
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
import torch
|
|
|
|
| 8 |
from qwen_vl_utils import process_vision_info
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoProcessor,
|
| 11 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 12 |
+
TextIteratorStreamer,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
DEFAULT_MODEL_PATH = "./checkpoints/full_precision"
|
| 16 |
+
DEFAULT_SYSTEM_PROMPT = "You are a professional AI dermatology assistant."
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def resolve_model_path(model_path: str = DEFAULT_MODEL_PATH) -> str:
|
| 20 |
+
"""Resolve a model path for both cloned-repo and local-dev layouts."""
|
| 21 |
+
raw_path = Path(model_path).expanduser()
|
| 22 |
+
repo_root = Path(__file__).resolve().parents[2]
|
| 23 |
+
candidates = [raw_path]
|
| 24 |
+
|
| 25 |
+
if not raw_path.is_absolute():
|
| 26 |
+
candidates.append(Path.cwd() / raw_path)
|
| 27 |
+
candidates.append(repo_root / raw_path)
|
| 28 |
+
if raw_path.parts and raw_path.parts[0] == repo_root.name:
|
| 29 |
+
candidates.append(repo_root.joinpath(*raw_path.parts[1:]))
|
| 30 |
+
|
| 31 |
+
for candidate in candidates:
|
| 32 |
+
if candidate.exists():
|
| 33 |
+
return str(candidate)
|
| 34 |
+
return str(raw_path)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_single_turn_messages(
|
| 38 |
+
image_path: str,
|
| 39 |
+
prompt: str,
|
| 40 |
+
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
| 41 |
+
) -> List[dict]:
|
| 42 |
+
return [
|
| 43 |
+
{
|
| 44 |
+
"role": "user",
|
| 45 |
+
"content": [
|
| 46 |
+
{"type": "image", "image": image_path},
|
| 47 |
+
{"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
|
| 48 |
+
],
|
| 49 |
+
}
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
|
| 53 |
class SkinGPTModel:
|
| 54 |
+
def __init__(self, model_path: str = DEFAULT_MODEL_PATH, device: str | None = None):
|
| 55 |
+
resolved_model_path = resolve_model_path(model_path)
|
| 56 |
+
self.model_path = resolved_model_path
|
| 57 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
print(f"Loading model from {resolved_model_path} on {self.device}...")
|
| 59 |
+
|
| 60 |
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 61 |
+
resolved_model_path,
|
| 62 |
torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
|
| 63 |
attn_implementation="flash_attention_2" if self.device == "cuda" else None,
|
| 64 |
device_map="auto" if self.device != "mps" else None,
|
| 65 |
+
trust_remote_code=True,
|
| 66 |
)
|
| 67 |
+
|
| 68 |
if self.device == "mps":
|
| 69 |
self.model = self.model.to(self.device)
|
| 70 |
|
| 71 |
self.processor = AutoProcessor.from_pretrained(
|
| 72 |
+
resolved_model_path,
|
| 73 |
+
trust_remote_code=True,
|
| 74 |
+
min_pixels=256 * 28 * 28,
|
| 75 |
+
max_pixels=1280 * 28 * 28,
|
| 76 |
)
|
| 77 |
print("Model loaded successfully.")
|
| 78 |
|
| 79 |
+
def generate_response(
|
| 80 |
+
self,
|
| 81 |
+
messages,
|
| 82 |
+
max_new_tokens: int = 1024,
|
| 83 |
+
temperature: float = 0.7,
|
| 84 |
+
repetition_penalty: float = 1.2,
|
| 85 |
+
no_repeat_ngram_size: int = 3,
|
| 86 |
+
) -> str:
|
| 87 |
+
text = self.processor.apply_chat_template(
|
| 88 |
+
messages,
|
| 89 |
+
tokenize=False,
|
| 90 |
+
add_generation_prompt=True,
|
| 91 |
+
)
|
| 92 |
image_inputs, video_inputs = process_vision_info(messages)
|
| 93 |
+
|
| 94 |
inputs = self.processor(
|
| 95 |
text=[text],
|
| 96 |
images=image_inputs,
|
|
|
|
| 107 |
repetition_penalty=repetition_penalty,
|
| 108 |
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 109 |
top_p=0.9,
|
| 110 |
+
do_sample=True,
|
| 111 |
)
|
| 112 |
|
|
|
|
| 113 |
generated_ids_trimmed = [
|
| 114 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 115 |
]
|
| 116 |
output_text = self.processor.batch_decode(
|
| 117 |
+
generated_ids_trimmed,
|
| 118 |
+
skip_special_tokens=True,
|
| 119 |
+
clean_up_tokenization_spaces=False,
|
| 120 |
)
|
| 121 |
+
|
| 122 |
return output_text[0]
|
| 123 |
+
|
| 124 |
+
def generate_response_stream(
|
| 125 |
+
self,
|
| 126 |
+
messages,
|
| 127 |
+
max_new_tokens: int = 1024,
|
| 128 |
+
temperature: float = 0.7,
|
| 129 |
+
repetition_penalty: float = 1.2,
|
| 130 |
+
no_repeat_ngram_size: int = 3,
|
| 131 |
+
):
|
| 132 |
+
text = self.processor.apply_chat_template(
|
| 133 |
+
messages,
|
| 134 |
+
tokenize=False,
|
| 135 |
+
add_generation_prompt=True,
|
| 136 |
+
)
|
| 137 |
image_inputs, video_inputs = process_vision_info(messages)
|
| 138 |
+
|
| 139 |
inputs = self.processor(
|
| 140 |
text=[text],
|
| 141 |
images=image_inputs,
|
|
|
|
| 143 |
padding=True,
|
| 144 |
return_tensors="pt",
|
| 145 |
).to(self.model.device)
|
| 146 |
+
|
|
|
|
| 147 |
streamer = TextIteratorStreamer(
|
| 148 |
self.processor.tokenizer,
|
| 149 |
skip_prompt=True,
|
| 150 |
+
skip_special_tokens=True,
|
| 151 |
)
|
| 152 |
+
|
|
|
|
| 153 |
generation_kwargs = {
|
| 154 |
**inputs,
|
| 155 |
"max_new_tokens": max_new_tokens,
|
|
|
|
| 160 |
"do_sample": True,
|
| 161 |
"streamer": streamer,
|
| 162 |
}
|
| 163 |
+
|
|
|
|
| 164 |
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
| 165 |
thread.start()
|
| 166 |
+
|
|
|
|
| 167 |
for text_chunk in streamer:
|
| 168 |
yield text_chunk
|
| 169 |
+
|
| 170 |
+
thread.join()
|
inference/full_precision/run_api.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/app.py"
|
inference/full_precision/run_chat.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/chat.py" "$@"
|
inference/full_precision/run_infer.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/infer.py" "$@"
|
inference/inference.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import argparse
|
| 3 |
-
from model_utils import SkinGPTModel
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
def main():
|
| 7 |
-
parser = argparse.ArgumentParser(description="SkinGPT-R1 Single Inference")
|
| 8 |
-
parser.add_argument("--image", type=str, required=True, help="Path to the image")
|
| 9 |
-
parser.add_argument("--model_path", type=str, default="../checkpoint")
|
| 10 |
-
parser.add_argument("--prompt", type=str, default="Please analyze this skin image and provide a diagnosis.")
|
| 11 |
-
args = parser.parse_args()
|
| 12 |
-
|
| 13 |
-
if not os.path.exists(args.image):
|
| 14 |
-
print(f"Error: Image not found at {args.image}")
|
| 15 |
-
return
|
| 16 |
-
|
| 17 |
-
# 1. 加载模型 (复用 model_utils)
|
| 18 |
-
# 这样你就不用在这里重复写 transformers 的加载代码了
|
| 19 |
-
bot = SkinGPTModel(args.model_path)
|
| 20 |
-
|
| 21 |
-
# 2. 构造单轮消息
|
| 22 |
-
system_prompt = "You are a professional AI dermatology assistant."
|
| 23 |
-
messages = [
|
| 24 |
-
{
|
| 25 |
-
"role": "user",
|
| 26 |
-
"content": [
|
| 27 |
-
{"type": "image", "image": args.image},
|
| 28 |
-
{"type": "text", "text": f"{system_prompt}\n\n{args.prompt}"}
|
| 29 |
-
]
|
| 30 |
-
}
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
# 3. 推理
|
| 34 |
-
print(f"\nAnalyzing {args.image}...")
|
| 35 |
-
response = bot.generate_response(messages)
|
| 36 |
-
|
| 37 |
-
print("-" * 40)
|
| 38 |
-
print("Result:")
|
| 39 |
-
print(response)
|
| 40 |
-
print("-" * 40)
|
| 41 |
-
|
| 42 |
-
if __name__ == "__main__":
|
| 43 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/int4_quantized/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""INT4 quantized inference package for SkinGPT-R1."""
|
inference/int4_quantized/__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (18 kB). View file
|
|
|
inference/int4_quantized/__pycache__/chat.cpython-311.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
inference/int4_quantized/__pycache__/infer.cpython-311.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
inference/int4_quantized/__pycache__/model_utils.cpython-311.pyc
ADDED
|
Binary file (28.9 kB). View file
|
|
|
inference/{.ipynb_checkpoints/app-checkpoint.py → int4_quantized/app.py}
RENAMED
|
@@ -1,133 +1,97 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
import os
|
| 4 |
import shutil
|
|
|
|
| 5 |
import uuid
|
| 6 |
-
import json
|
| 7 |
-
import re
|
| 8 |
-
import asyncio
|
| 9 |
-
from typing import Optional
|
| 10 |
-
from io import BytesIO
|
| 11 |
from contextlib import asynccontextmanager
|
| 12 |
-
from
|
| 13 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
from fastapi.responses import StreamingResponse
|
| 16 |
-
from
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
-
# Global DeepSeek service instance
|
| 29 |
deepseek_service: Optional[DeepSeekService] = None
|
| 30 |
|
| 31 |
-
@asynccontextmanager
|
| 32 |
-
async def lifespan(app: FastAPI):
|
| 33 |
-
"""应用生命周期管理"""
|
| 34 |
-
# 启动时初始化 DeepSeek 服务
|
| 35 |
-
await init_deepseek()
|
| 36 |
-
yield
|
| 37 |
-
print("\nShutting down service...")
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
description="智能皮肤诊断助手",
|
| 42 |
-
version="1.0.0",
|
| 43 |
-
lifespan=lifespan
|
| 44 |
-
)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
CORSMiddleware,
|
| 49 |
-
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 50 |
-
allow_credentials=True,
|
| 51 |
-
allow_methods=["*"],
|
| 52 |
-
allow_headers=["*"],
|
| 53 |
-
)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# pending_images: 存储已上传但尚未发送给LLM的图片路径 (State ID -> Image Path)
|
| 58 |
-
chat_states = {}
|
| 59 |
-
pending_images = {}
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
解析诊断结果中的think和answer标签
|
| 64 |
-
|
| 65 |
-
参数:
|
| 66 |
-
- raw_text: 原始诊断文本
|
| 67 |
-
|
| 68 |
-
返回:
|
| 69 |
-
- dict: 包含thinking, answer, raw字段的字典
|
| 70 |
-
"""
|
| 71 |
-
import re
|
| 72 |
-
|
| 73 |
-
# 尝试匹配完整的标签
|
| 74 |
-
think_match = re.search(r'<think>([\s\S]*?)</think>', raw_text)
|
| 75 |
-
answer_match = re.search(r'<answer>([\s\S]*?)</answer>', raw_text)
|
| 76 |
-
|
| 77 |
-
thinking = None
|
| 78 |
-
answer = None
|
| 79 |
-
|
| 80 |
-
# 处理think标签
|
| 81 |
-
if think_match:
|
| 82 |
-
thinking = think_match.group(1).strip()
|
| 83 |
-
else:
|
| 84 |
-
# 尝试匹配未闭合的think标签(输出被截断的情况)
|
| 85 |
-
unclosed_think = re.search(r'<think>([\s\S]*?)(?=<answer>|$)', raw_text)
|
| 86 |
if unclosed_think:
|
| 87 |
thinking = unclosed_think.group(1).strip()
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
answer = answer_match.group(1).strip()
|
| 92 |
-
else:
|
| 93 |
-
# 尝试匹配未闭合的answer标签
|
| 94 |
-
unclosed_answer = re.search(r'<answer>([\s\S]*?)$', raw_text)
|
| 95 |
if unclosed_answer:
|
| 96 |
answer = unclosed_answer.group(1).strip()
|
| 97 |
-
|
| 98 |
-
# 如果仍然没有找到answer,清理原始文本作为answer
|
| 99 |
if not answer:
|
| 100 |
-
|
| 101 |
-
cleaned = re.sub(r
|
| 102 |
-
cleaned = re.sub(r
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
answer = cleaned if cleaned else raw_text
|
| 106 |
-
|
| 107 |
-
# 清理可能残留的标签
|
| 108 |
-
if answer:
|
| 109 |
-
answer = re.sub(r'</?think>|</?answer>', '', answer).strip()
|
| 110 |
-
if thinking:
|
| 111 |
-
thinking = re.sub(r'</?think>|</?answer>', '', thinking).strip()
|
| 112 |
-
|
| 113 |
-
# 处理 "Final Answer:" 格式,提取其后的内容
|
| 114 |
if answer:
|
| 115 |
-
|
|
|
|
| 116 |
if final_answer_match:
|
| 117 |
answer = final_answer_match.group(1).strip()
|
| 118 |
-
|
| 119 |
-
return {
|
| 120 |
-
"thinking": thinking if thinking else None,
|
| 121 |
-
"answer": answer,
|
| 122 |
-
"raw": raw_text
|
| 123 |
-
}
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
# 初始化 DeepSeek 服务(异步)
|
| 131 |
async def init_deepseek():
|
| 132 |
global deepseek_service
|
| 133 |
print("\nInitializing DeepSeek service...")
|
|
@@ -137,120 +101,115 @@ async def init_deepseek():
|
|
| 137 |
else:
|
| 138 |
print("DeepSeek service not available, will return raw results")
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
@app.post("/v1/upload/{state_id}")
|
| 141 |
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 142 |
-
|
| 143 |
-
接收图片上传。
|
| 144 |
-
逻辑:将图片保存到本地临时目录,并标记该 state_id 有一张待处理图片。
|
| 145 |
-
"""
|
| 146 |
try:
|
| 147 |
-
# 1. 保存图片到本地临时文件
|
| 148 |
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 149 |
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 150 |
-
file_path =
|
| 151 |
-
|
| 152 |
-
with open(
|
| 153 |
shutil.copyfileobj(file.file, buffer)
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
# 如果是多图模式,这里可以改成 list,目前演示单图覆盖或更新
|
| 157 |
-
pending_images[state_id] = file_path
|
| 158 |
-
|
| 159 |
-
# 3. 初始化对话状态(如果是新会话)
|
| 160 |
if state_id not in chat_states:
|
| 161 |
chat_states[state_id] = []
|
| 162 |
-
|
| 163 |
-
return {"message": "Image uploaded successfully", "path": file_path}
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
|
| 168 |
@app.post("/v1/predict/{state_id}")
|
| 169 |
async def v1_predict(request: Request, state_id: str):
|
| 170 |
-
"""
|
| 171 |
-
接收文本并执行推理。
|
| 172 |
-
逻辑:检查是否有待处理图片。如果有,将其与文本组合成 multimodal 消息。
|
| 173 |
-
"""
|
| 174 |
try:
|
| 175 |
data = await request.json()
|
| 176 |
-
except:
|
| 177 |
-
raise HTTPException(status_code=400, detail="Invalid JSON")
|
| 178 |
-
|
| 179 |
user_message = data.get("message", "")
|
| 180 |
if not user_message:
|
| 181 |
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 182 |
|
| 183 |
-
# 获取或初始化历史
|
| 184 |
history = chat_states.get(state_id, [])
|
| 185 |
-
|
| 186 |
-
# 构建当前轮次的用户内容
|
| 187 |
current_content = []
|
| 188 |
-
|
| 189 |
-
# 1. 检查是否有刚刚上传的图片
|
| 190 |
if state_id in pending_images:
|
| 191 |
-
img_path = pending_images.pop(state_id)
|
| 192 |
current_content.append({"type": "image", "image": img_path})
|
| 193 |
-
|
| 194 |
-
# ��果是第一次对话,加上 System Prompt
|
| 195 |
if not history:
|
| 196 |
-
|
| 197 |
-
user_message = f"{system_prompt}\n\n{user_message}"
|
| 198 |
|
| 199 |
-
# 2. 添加文本
|
| 200 |
current_content.append({"type": "text", "text": user_message})
|
| 201 |
-
|
| 202 |
-
# 3. 更新历史
|
| 203 |
history.append({"role": "user", "content": current_content})
|
| 204 |
chat_states[state_id] = history
|
| 205 |
|
| 206 |
-
# 4. 运行推理 (在线程池中运行以防阻塞)
|
| 207 |
try:
|
| 208 |
-
response_text = await run_in_threadpool(
|
| 209 |
-
|
| 210 |
-
messages=history
|
| 211 |
-
)
|
| 212 |
-
except Exception as e:
|
| 213 |
-
# 回滚历史(移除刚才出错的用户提问)
|
| 214 |
chat_states[state_id].pop()
|
| 215 |
-
raise HTTPException(status_code=500, detail=f"Inference error: {
|
| 216 |
|
| 217 |
-
# 5. 将回复加入历史
|
| 218 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 219 |
chat_states[state_id] = history
|
| 220 |
-
|
| 221 |
return {"message": response_text}
|
| 222 |
|
|
|
|
| 223 |
@app.post("/v1/reset/{state_id}")
|
| 224 |
async def reset_chat(state_id: str):
|
| 225 |
-
"""清除会话状态"""
|
| 226 |
if state_id in chat_states:
|
| 227 |
del chat_states[state_id]
|
| 228 |
if state_id in pending_images:
|
| 229 |
-
# 可选:删除临时文件
|
| 230 |
try:
|
| 231 |
-
|
| 232 |
-
except:
|
| 233 |
pass
|
| 234 |
del pending_images[state_id]
|
| 235 |
return {"message": "Chat history reset"}
|
| 236 |
|
|
|
|
| 237 |
@app.get("/")
|
| 238 |
async def root():
|
| 239 |
-
"""根路径"""
|
| 240 |
return {
|
| 241 |
-
"name": "SkinGPT-R1
|
| 242 |
-
"version": "1.
|
| 243 |
"status": "running",
|
| 244 |
-
"description": "
|
| 245 |
}
|
| 246 |
|
|
|
|
| 247 |
@app.get("/health")
|
| 248 |
async def health_check():
|
| 249 |
-
"""
|
| 250 |
-
|
| 251 |
-
"status": "healthy",
|
| 252 |
-
"model_loaded": True
|
| 253 |
-
}
|
| 254 |
|
| 255 |
@app.post("/diagnose/stream")
|
| 256 |
async def diagnose_stream(
|
|
@@ -258,126 +217,89 @@ async def diagnose_stream(
|
|
| 258 |
text: str = Form(...),
|
| 259 |
language: str = Form("zh"),
|
| 260 |
):
|
| 261 |
-
"""
|
| 262 |
-
SSE流式诊断接口(用于前端)
|
| 263 |
-
支持图片上传和文本输入,返回真正的流式响应
|
| 264 |
-
使用 DeepSeek API 优化输出格式
|
| 265 |
-
"""
|
| 266 |
-
from queue import Queue, Empty
|
| 267 |
-
from threading import Thread
|
| 268 |
-
|
| 269 |
language = language if language in ("zh", "en") else "zh"
|
| 270 |
-
|
| 271 |
-
# 处理图片
|
| 272 |
pil_image = None
|
| 273 |
-
|
| 274 |
-
|
| 275 |
if image:
|
| 276 |
contents = await image.read()
|
| 277 |
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 278 |
-
|
| 279 |
-
# 创建队列用于线程间通信
|
| 280 |
result_queue = Queue()
|
| 281 |
-
# 用于存储完整响应和解析结果
|
| 282 |
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 283 |
-
|
| 284 |
def run_generation():
|
| 285 |
-
"""在后台线程中运行流式生成"""
|
| 286 |
full_response = []
|
| 287 |
-
|
| 288 |
try:
|
| 289 |
-
# 构建消息
|
| 290 |
messages = []
|
| 291 |
current_content = []
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
| 297 |
if pil_image:
|
| 298 |
-
|
| 299 |
-
pil_image.save(
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
current_content.append({"type": "text", "text": prompt})
|
| 305 |
messages.append({"role": "user", "content": current_content})
|
| 306 |
-
|
| 307 |
-
# 流式生成 - 每个 chunk 立即���入队列
|
| 308 |
for chunk in gpt_model.generate_response_stream(
|
| 309 |
messages=messages,
|
| 310 |
-
max_new_tokens=
|
| 311 |
-
|
|
|
|
| 312 |
):
|
| 313 |
full_response.append(chunk)
|
| 314 |
result_queue.put(("delta", chunk))
|
| 315 |
-
|
| 316 |
-
# 解析结果
|
| 317 |
response_text = "".join(full_response)
|
| 318 |
-
parsed = parse_diagnosis_result(response_text)
|
| 319 |
generation_result["full_response"] = full_response
|
| 320 |
-
generation_result["parsed"] =
|
| 321 |
-
|
| 322 |
-
# 标记生成完成
|
| 323 |
result_queue.put(("generation_done", None))
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
async def event_generator():
|
| 329 |
-
"""异步生成SSE事件"""
|
| 330 |
-
# 在后台线程启动生成(非阻塞)
|
| 331 |
gen_thread = Thread(target=run_generation)
|
| 332 |
gen_thread.start()
|
| 333 |
-
|
| 334 |
loop = asyncio.get_event_loop()
|
| 335 |
-
|
| 336 |
-
# 从队列中读取并发送流式内容
|
| 337 |
while True:
|
| 338 |
try:
|
| 339 |
-
# 非阻塞获取
|
| 340 |
msg_type, data = await loop.run_in_executor(
|
| 341 |
-
None,
|
| 342 |
-
lambda: result_queue.get(timeout=0.1)
|
| 343 |
)
|
| 344 |
-
|
| 345 |
if msg_type == "generation_done":
|
| 346 |
-
# 流式生成完成,准备处理最终结果
|
| 347 |
break
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
yield f"data: {yield_chunk}\n\n"
|
| 351 |
elif msg_type == "error":
|
| 352 |
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 353 |
gen_thread.join()
|
| 354 |
return
|
| 355 |
-
|
| 356 |
except Empty:
|
| 357 |
-
# 队列暂时为空,继续等待
|
| 358 |
await asyncio.sleep(0.01)
|
| 359 |
-
|
| 360 |
-
|
| 361 |
gen_thread.join()
|
| 362 |
-
|
| 363 |
-
# 获取解析结果
|
| 364 |
parsed = generation_result["parsed"]
|
| 365 |
if not parsed:
|
| 366 |
-
yield
|
| 367 |
return
|
| 368 |
-
|
| 369 |
raw_thinking = parsed["thinking"]
|
| 370 |
raw_answer = parsed["answer"]
|
| 371 |
-
|
| 372 |
-
# 使用 DeepSeek 优化结果
|
| 373 |
refined_by_deepseek = False
|
| 374 |
description = None
|
| 375 |
thinking = raw_thinking
|
| 376 |
answer = raw_answer
|
| 377 |
-
|
| 378 |
if deepseek_service and deepseek_service.is_loaded:
|
| 379 |
try:
|
| 380 |
-
print(f"Calling DeepSeek to refine diagnosis (language={language})...")
|
| 381 |
refined = await deepseek_service.refine_diagnosis(
|
| 382 |
raw_answer=raw_answer,
|
| 383 |
raw_thinking=raw_thinking,
|
|
@@ -388,36 +310,35 @@ async def diagnose_stream(
|
|
| 388 |
thinking = refined["analysis_process"]
|
| 389 |
answer = refined["diagnosis_result"]
|
| 390 |
refined_by_deepseek = True
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
print(f"DeepSeek refinement failed, using original: {e}")
|
| 394 |
else:
|
| 395 |
print("DeepSeek service not available, using raw results")
|
| 396 |
-
|
| 397 |
-
success_msg = "Diagnosis completed" if language == "en" else "诊断完成"
|
| 398 |
-
|
| 399 |
-
# 返回格式与参考项目保持一致
|
| 400 |
final_payload = {
|
| 401 |
-
"description": description,
|
| 402 |
-
"thinking": thinking,
|
| 403 |
-
"answer": answer,
|
| 404 |
-
"raw": parsed["raw"],
|
| 405 |
-
"refined_by_deepseek": refined_by_deepseek,
|
| 406 |
"success": True,
|
| 407 |
-
"message":
|
| 408 |
}
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
# 清理临时图片
|
| 413 |
temp_path = generation_result.get("temp_image_path")
|
| 414 |
-
if temp_path
|
| 415 |
try:
|
| 416 |
-
|
| 417 |
-
except:
|
| 418 |
pass
|
| 419 |
-
|
| 420 |
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import json
|
| 5 |
import os
|
| 6 |
import shutil
|
| 7 |
+
import sys
|
| 8 |
import uuid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from contextlib import asynccontextmanager
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from queue import Empty, Queue
|
| 13 |
+
from threading import Thread
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
import uvicorn
|
| 17 |
+
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
| 18 |
+
from fastapi.concurrency import run_in_threadpool
|
| 19 |
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
from fastapi.responses import StreamingResponse
|
| 21 |
+
from PIL import Image
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from .model_utils import (
|
| 25 |
+
DEFAULT_DO_SAMPLE,
|
| 26 |
+
DEFAULT_MAX_NEW_TOKENS,
|
| 27 |
+
DEFAULT_MODEL_PATH,
|
| 28 |
+
DEFAULT_REPETITION_PENALTY,
|
| 29 |
+
QuantizedSkinGPTModel,
|
| 30 |
+
)
|
| 31 |
+
except ImportError:
|
| 32 |
+
from model_utils import (
|
| 33 |
+
DEFAULT_DO_SAMPLE,
|
| 34 |
+
DEFAULT_MAX_NEW_TOKENS,
|
| 35 |
+
DEFAULT_MODEL_PATH,
|
| 36 |
+
DEFAULT_REPETITION_PENALTY,
|
| 37 |
+
QuantizedSkinGPTModel,
|
| 38 |
+
)
|
| 39 |
|
| 40 |
+
try:
|
| 41 |
+
from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
|
| 42 |
+
except ImportError:
|
| 43 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
| 44 |
+
from inference.full_precision.deepseek_service import DeepSeekService, get_deepseek_service
|
| 45 |
|
| 46 |
+
TEMP_DIR = Path(__file__).resolve().parents[1] / "temp_uploads"
|
| 47 |
+
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
|
| 49 |
|
|
|
|
| 50 |
deepseek_service: Optional[DeepSeekService] = None
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
+
def parse_diagnosis_result(raw_text: str) -> dict:
|
| 54 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
think_match = re.search(r"<think>([\s\S]*?)</think>", raw_text)
|
| 57 |
+
answer_match = re.search(r"<answer>([\s\S]*?)</answer>", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
thinking = think_match.group(1).strip() if think_match else None
|
| 60 |
+
answer = answer_match.group(1).strip() if answer_match else None
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
if not thinking:
|
| 63 |
+
unclosed_think = re.search(r"<think>([\s\S]*?)(?=<answer>|$)", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if unclosed_think:
|
| 65 |
thinking = unclosed_think.group(1).strip()
|
| 66 |
+
|
| 67 |
+
if not answer:
|
| 68 |
+
unclosed_answer = re.search(r"<answer>([\s\S]*?)$", raw_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if unclosed_answer:
|
| 70 |
answer = unclosed_answer.group(1).strip()
|
| 71 |
+
|
|
|
|
| 72 |
if not answer:
|
| 73 |
+
cleaned = re.sub(r"<think>[\s\S]*?</think>", "", raw_text)
|
| 74 |
+
cleaned = re.sub(r"<think>[\s\S]*", "", cleaned)
|
| 75 |
+
cleaned = re.sub(r"</?answer>", "", cleaned)
|
| 76 |
+
answer = cleaned.strip() or raw_text
|
| 77 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if answer:
|
| 79 |
+
answer = re.sub(r"</?think>|</?answer>", "", answer).strip()
|
| 80 |
+
final_answer_match = re.search(r"Final Answer:\s*([\s\S]*)", answer, re.IGNORECASE)
|
| 81 |
if final_answer_match:
|
| 82 |
answer = final_answer_match.group(1).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
if thinking:
|
| 85 |
+
thinking = re.sub(r"</?think>|</?answer>", "", thinking).strip()
|
| 86 |
+
|
| 87 |
+
return {"thinking": thinking or None, "answer": answer, "raw": raw_text}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
print("Initializing INT4 Model Service...")
|
| 91 |
+
gpt_model = QuantizedSkinGPTModel(DEFAULT_MODEL_PATH)
|
| 92 |
+
print("INT4 service ready.")
|
| 93 |
+
|
| 94 |
|
|
|
|
| 95 |
async def init_deepseek():
|
| 96 |
global deepseek_service
|
| 97 |
print("\nInitializing DeepSeek service...")
|
|
|
|
| 101 |
else:
|
| 102 |
print("DeepSeek service not available, will return raw results")
|
| 103 |
|
| 104 |
+
|
| 105 |
+
@asynccontextmanager
|
| 106 |
+
async def lifespan(app: FastAPI):
|
| 107 |
+
await init_deepseek()
|
| 108 |
+
yield
|
| 109 |
+
print("\nShutting down INT4 service...")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
app = FastAPI(
|
| 113 |
+
title="SkinGPT-R1 INT4 API",
|
| 114 |
+
description="INT4 quantized dermatology assistant backend",
|
| 115 |
+
version="1.1.0",
|
| 116 |
+
lifespan=lifespan,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
app.add_middleware(
|
| 120 |
+
CORSMiddleware,
|
| 121 |
+
allow_origins=["http://localhost:3000", "http://localhost:5173", "http://127.0.0.1:5173", "*"],
|
| 122 |
+
allow_credentials=True,
|
| 123 |
+
allow_methods=["*"],
|
| 124 |
+
allow_headers=["*"],
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
chat_states = {}
|
| 128 |
+
pending_images = {}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
@app.post("/v1/upload/{state_id}")
|
| 132 |
async def upload_file(state_id: str, file: UploadFile = File(...), survey: str = Form(None)):
|
| 133 |
+
del survey
|
|
|
|
|
|
|
|
|
|
| 134 |
try:
|
|
|
|
| 135 |
file_extension = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
| 136 |
unique_name = f"{state_id}_{uuid.uuid4().hex}.{file_extension}"
|
| 137 |
+
file_path = TEMP_DIR / unique_name
|
| 138 |
+
|
| 139 |
+
with file_path.open("wb") as buffer:
|
| 140 |
shutil.copyfileobj(file.file, buffer)
|
| 141 |
+
|
| 142 |
+
pending_images[state_id] = str(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
if state_id not in chat_states:
|
| 144 |
chat_states[state_id] = []
|
| 145 |
+
|
| 146 |
+
return {"message": "Image uploaded successfully", "path": str(file_path)}
|
| 147 |
+
except Exception as exc:
|
| 148 |
+
raise HTTPException(status_code=500, detail=f"Upload failed: {exc}") from exc
|
| 149 |
+
|
| 150 |
|
| 151 |
@app.post("/v1/predict/{state_id}")
|
| 152 |
async def v1_predict(request: Request, state_id: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
try:
|
| 154 |
data = await request.json()
|
| 155 |
+
except Exception as exc:
|
| 156 |
+
raise HTTPException(status_code=400, detail="Invalid JSON") from exc
|
| 157 |
+
|
| 158 |
user_message = data.get("message", "")
|
| 159 |
if not user_message:
|
| 160 |
raise HTTPException(status_code=400, detail="Missing 'message' field")
|
| 161 |
|
|
|
|
| 162 |
history = chat_states.get(state_id, [])
|
|
|
|
|
|
|
| 163 |
current_content = []
|
| 164 |
+
|
|
|
|
| 165 |
if state_id in pending_images:
|
| 166 |
+
img_path = pending_images.pop(state_id)
|
| 167 |
current_content.append({"type": "image", "image": img_path})
|
|
|
|
|
|
|
| 168 |
if not history:
|
| 169 |
+
user_message = f"You are a professional AI dermatology assistant.\n\n{user_message}"
|
|
|
|
| 170 |
|
|
|
|
| 171 |
current_content.append({"type": "text", "text": user_message})
|
|
|
|
|
|
|
| 172 |
history.append({"role": "user", "content": current_content})
|
| 173 |
chat_states[state_id] = history
|
| 174 |
|
|
|
|
| 175 |
try:
|
| 176 |
+
response_text = await run_in_threadpool(gpt_model.generate_response, messages=history)
|
| 177 |
+
except Exception as exc:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
chat_states[state_id].pop()
|
| 179 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {exc}") from exc
|
| 180 |
|
|
|
|
| 181 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response_text}]})
|
| 182 |
chat_states[state_id] = history
|
|
|
|
| 183 |
return {"message": response_text}
|
| 184 |
|
| 185 |
+
|
| 186 |
@app.post("/v1/reset/{state_id}")
|
| 187 |
async def reset_chat(state_id: str):
|
|
|
|
| 188 |
if state_id in chat_states:
|
| 189 |
del chat_states[state_id]
|
| 190 |
if state_id in pending_images:
|
|
|
|
| 191 |
try:
|
| 192 |
+
Path(pending_images[state_id]).unlink(missing_ok=True)
|
| 193 |
+
except Exception:
|
| 194 |
pass
|
| 195 |
del pending_images[state_id]
|
| 196 |
return {"message": "Chat history reset"}
|
| 197 |
|
| 198 |
+
|
| 199 |
@app.get("/")
|
| 200 |
async def root():
|
|
|
|
| 201 |
return {
|
| 202 |
+
"name": "SkinGPT-R1 INT4 API",
|
| 203 |
+
"version": "1.1.0",
|
| 204 |
"status": "running",
|
| 205 |
+
"description": "INT4 quantized dermatology assistant",
|
| 206 |
}
|
| 207 |
|
| 208 |
+
|
| 209 |
@app.get("/health")
|
| 210 |
async def health_check():
|
| 211 |
+
return {"status": "healthy", "model_loaded": True}
|
| 212 |
+
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
@app.post("/diagnose/stream")
|
| 215 |
async def diagnose_stream(
|
|
|
|
| 217 |
text: str = Form(...),
|
| 218 |
language: str = Form("zh"),
|
| 219 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
language = language if language in ("zh", "en") else "zh"
|
|
|
|
|
|
|
| 221 |
pil_image = None
|
| 222 |
+
|
|
|
|
| 223 |
if image:
|
| 224 |
contents = await image.read()
|
| 225 |
pil_image = Image.open(BytesIO(contents)).convert("RGB")
|
| 226 |
+
|
|
|
|
| 227 |
result_queue = Queue()
|
|
|
|
| 228 |
generation_result = {"full_response": [], "parsed": None, "temp_image_path": None}
|
| 229 |
+
|
| 230 |
def run_generation():
|
|
|
|
| 231 |
full_response = []
|
|
|
|
| 232 |
try:
|
|
|
|
| 233 |
messages = []
|
| 234 |
current_content = []
|
| 235 |
+
system_prompt = (
|
| 236 |
+
"You are a professional AI dermatology assistant."
|
| 237 |
+
if language == "en"
|
| 238 |
+
else "你是一个专业的AI皮肤科助手。"
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
if pil_image:
|
| 242 |
+
temp_image_path = TEMP_DIR / f"temp_{uuid.uuid4().hex}.jpg"
|
| 243 |
+
pil_image.save(temp_image_path)
|
| 244 |
+
generation_result["temp_image_path"] = str(temp_image_path)
|
| 245 |
+
current_content.append({"type": "image", "image": str(temp_image_path)})
|
| 246 |
+
|
| 247 |
+
current_content.append({"type": "text", "text": f"{system_prompt}\n\n{text}"})
|
|
|
|
| 248 |
messages.append({"role": "user", "content": current_content})
|
| 249 |
+
|
|
|
|
| 250 |
for chunk in gpt_model.generate_response_stream(
|
| 251 |
messages=messages,
|
| 252 |
+
max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
|
| 253 |
+
do_sample=DEFAULT_DO_SAMPLE,
|
| 254 |
+
repetition_penalty=DEFAULT_REPETITION_PENALTY,
|
| 255 |
):
|
| 256 |
full_response.append(chunk)
|
| 257 |
result_queue.put(("delta", chunk))
|
| 258 |
+
|
|
|
|
| 259 |
response_text = "".join(full_response)
|
|
|
|
| 260 |
generation_result["full_response"] = full_response
|
| 261 |
+
generation_result["parsed"] = parse_diagnosis_result(response_text)
|
|
|
|
|
|
|
| 262 |
result_queue.put(("generation_done", None))
|
| 263 |
+
except Exception as exc:
|
| 264 |
+
result_queue.put(("error", str(exc)))
|
| 265 |
+
|
|
|
|
| 266 |
async def event_generator():
|
|
|
|
|
|
|
| 267 |
gen_thread = Thread(target=run_generation)
|
| 268 |
gen_thread.start()
|
| 269 |
+
|
| 270 |
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
| 271 |
while True:
|
| 272 |
try:
|
|
|
|
| 273 |
msg_type, data = await loop.run_in_executor(
|
| 274 |
+
None,
|
| 275 |
+
lambda: result_queue.get(timeout=0.1),
|
| 276 |
)
|
|
|
|
| 277 |
if msg_type == "generation_done":
|
|
|
|
| 278 |
break
|
| 279 |
+
if msg_type == "delta":
|
| 280 |
+
yield f"data: {json.dumps({'type': 'delta', 'text': data}, ensure_ascii=False)}\n\n"
|
|
|
|
| 281 |
elif msg_type == "error":
|
| 282 |
yield f"data: {json.dumps({'type': 'error', 'message': data}, ensure_ascii=False)}\n\n"
|
| 283 |
gen_thread.join()
|
| 284 |
return
|
|
|
|
| 285 |
except Empty:
|
|
|
|
| 286 |
await asyncio.sleep(0.01)
|
| 287 |
+
|
|
|
|
| 288 |
gen_thread.join()
|
|
|
|
|
|
|
| 289 |
parsed = generation_result["parsed"]
|
| 290 |
if not parsed:
|
| 291 |
+
yield "data: {\"type\": \"error\", \"message\": \"Failed to parse response\"}\n\n"
|
| 292 |
return
|
| 293 |
+
|
| 294 |
raw_thinking = parsed["thinking"]
|
| 295 |
raw_answer = parsed["answer"]
|
|
|
|
|
|
|
| 296 |
refined_by_deepseek = False
|
| 297 |
description = None
|
| 298 |
thinking = raw_thinking
|
| 299 |
answer = raw_answer
|
| 300 |
+
|
| 301 |
if deepseek_service and deepseek_service.is_loaded:
|
| 302 |
try:
|
|
|
|
| 303 |
refined = await deepseek_service.refine_diagnosis(
|
| 304 |
raw_answer=raw_answer,
|
| 305 |
raw_thinking=raw_thinking,
|
|
|
|
| 310 |
thinking = refined["analysis_process"]
|
| 311 |
answer = refined["diagnosis_result"]
|
| 312 |
refined_by_deepseek = True
|
| 313 |
+
except Exception as exc:
|
| 314 |
+
print(f"DeepSeek refinement failed, using original: {exc}")
|
|
|
|
| 315 |
else:
|
| 316 |
print("DeepSeek service not available, using raw results")
|
| 317 |
+
|
|
|
|
|
|
|
|
|
|
| 318 |
final_payload = {
|
| 319 |
+
"description": description,
|
| 320 |
+
"thinking": thinking,
|
| 321 |
+
"answer": answer,
|
| 322 |
+
"raw": parsed["raw"],
|
| 323 |
+
"refined_by_deepseek": refined_by_deepseek,
|
| 324 |
"success": True,
|
| 325 |
+
"message": "Diagnosis completed" if language == "en" else "诊断完成",
|
| 326 |
}
|
| 327 |
+
yield f"data: {json.dumps({'type': 'final', 'result': final_payload}, ensure_ascii=False)}\n\n"
|
| 328 |
+
|
|
|
|
|
|
|
| 329 |
temp_path = generation_result.get("temp_image_path")
|
| 330 |
+
if temp_path:
|
| 331 |
try:
|
| 332 |
+
Path(temp_path).unlink(missing_ok=True)
|
| 333 |
+
except Exception:
|
| 334 |
pass
|
| 335 |
+
|
| 336 |
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
| 337 |
|
| 338 |
+
|
| 339 |
+
def main() -> None:
|
| 340 |
+
uvicorn.run("app:app", host="0.0.0.0", port=5901, reload=False)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
if __name__ == "__main__":
|
| 344 |
+
main()
|
inference/{.ipynb_checkpoints/chat-checkpoint.py → int4_quantized/chat.py}
RENAMED
|
@@ -1,48 +1,51 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
import argparse
|
| 3 |
-
import
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
def
|
| 7 |
-
parser = argparse.ArgumentParser(description="SkinGPT-R1
|
| 8 |
-
parser.add_argument("--model_path", type=str, default=
|
| 9 |
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 10 |
-
|
| 11 |
|
| 12 |
-
# 初始化模型
|
| 13 |
-
bot = SkinGPTModel(args.model_path)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
# 构造第一条包含图片的消息
|
| 20 |
-
if not os.path.exists(args.image):
|
| 21 |
print(f"Error: Image {args.image} not found.")
|
| 22 |
return
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
]
|
| 31 |
-
}
|
| 32 |
-
]
|
| 33 |
|
| 34 |
-
print("\n=== SkinGPT-R1 Chat (Type 'exit' to quit) ===")
|
| 35 |
print(f"Image loaded: {args.image}")
|
| 36 |
-
|
| 37 |
-
# 获取第一轮诊断
|
| 38 |
print("\nModel is thinking...", end="", flush=True)
|
| 39 |
-
response =
|
| 40 |
print(f"\rAssistant: {response}\n")
|
| 41 |
-
|
| 42 |
-
# 将助手的回复加入历史
|
| 43 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 44 |
|
| 45 |
-
# 进入多轮对话循环
|
| 46 |
while True:
|
| 47 |
try:
|
| 48 |
user_input = input("User: ")
|
|
@@ -51,18 +54,14 @@ def main():
|
|
| 51 |
if not user_input.strip():
|
| 52 |
continue
|
| 53 |
|
| 54 |
-
# 加入用户的新问题
|
| 55 |
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
| 56 |
-
|
| 57 |
print("Model is thinking...", end="", flush=True)
|
| 58 |
-
response =
|
| 59 |
print(f"\rAssistant: {response}\n")
|
| 60 |
-
|
| 61 |
-
# 加入助手的新回复
|
| 62 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 63 |
-
|
| 64 |
except KeyboardInterrupt:
|
| 65 |
break
|
| 66 |
|
|
|
|
| 67 |
if __name__ == "__main__":
|
| 68 |
-
main()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from .model_utils import (
|
| 8 |
+
DEFAULT_MODEL_PATH,
|
| 9 |
+
QuantizedSkinGPTModel,
|
| 10 |
+
build_single_turn_messages,
|
| 11 |
+
)
|
| 12 |
+
except ImportError:
|
| 13 |
+
from model_utils import (
|
| 14 |
+
DEFAULT_MODEL_PATH,
|
| 15 |
+
QuantizedSkinGPTModel,
|
| 16 |
+
build_single_turn_messages,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
|
| 20 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 21 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 multi-turn chat")
|
| 22 |
+
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
|
| 23 |
parser.add_argument("--image", type=str, required=True, help="Path to initial image")
|
| 24 |
+
return parser
|
| 25 |
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
def main() -> None:
|
| 28 |
+
args = build_parser().parse_args()
|
| 29 |
+
|
| 30 |
+
if not Path(args.image).exists():
|
|
|
|
|
|
|
| 31 |
print(f"Error: Image {args.image} not found.")
|
| 32 |
return
|
| 33 |
|
| 34 |
+
model = QuantizedSkinGPTModel(args.model_path)
|
| 35 |
+
history = build_single_turn_messages(
|
| 36 |
+
args.image,
|
| 37 |
+
"Please analyze this image.",
|
| 38 |
+
system_prompt="You are a professional AI dermatology assistant. Analyze the skin condition carefully.",
|
| 39 |
+
)
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
print("\n=== SkinGPT-R1 INT4 Chat (Type 'exit' to quit) ===")
|
| 42 |
print(f"Image loaded: {args.image}")
|
| 43 |
+
|
|
|
|
| 44 |
print("\nModel is thinking...", end="", flush=True)
|
| 45 |
+
response = model.generate_response(history)
|
| 46 |
print(f"\rAssistant: {response}\n")
|
|
|
|
|
|
|
| 47 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
| 48 |
|
|
|
|
| 49 |
while True:
|
| 50 |
try:
|
| 51 |
user_input = input("User: ")
|
|
|
|
| 54 |
if not user_input.strip():
|
| 55 |
continue
|
| 56 |
|
|
|
|
| 57 |
history.append({"role": "user", "content": [{"type": "text", "text": user_input}]})
|
|
|
|
| 58 |
print("Model is thinking...", end="", flush=True)
|
| 59 |
+
response = model.generate_response(history)
|
| 60 |
print(f"\rAssistant: {response}\n")
|
|
|
|
|
|
|
| 61 |
history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
|
|
|
|
| 62 |
except KeyboardInterrupt:
|
| 63 |
break
|
| 64 |
|
| 65 |
+
|
| 66 |
if __name__ == "__main__":
|
| 67 |
+
main()
|
inference/int4_quantized/infer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from .model_utils import (
|
| 9 |
+
DEFAULT_DO_SAMPLE,
|
| 10 |
+
DEFAULT_MODEL_PATH,
|
| 11 |
+
DEFAULT_MAX_NEW_TOKENS,
|
| 12 |
+
DEFAULT_PROMPT,
|
| 13 |
+
DEFAULT_REPETITION_PENALTY,
|
| 14 |
+
DEFAULT_TEMPERATURE,
|
| 15 |
+
DEFAULT_TOP_P,
|
| 16 |
+
QuantizedSkinGPTModel,
|
| 17 |
+
build_single_turn_messages,
|
| 18 |
+
)
|
| 19 |
+
except ImportError:
|
| 20 |
+
from model_utils import (
|
| 21 |
+
DEFAULT_DO_SAMPLE,
|
| 22 |
+
DEFAULT_MODEL_PATH,
|
| 23 |
+
DEFAULT_MAX_NEW_TOKENS,
|
| 24 |
+
DEFAULT_PROMPT,
|
| 25 |
+
DEFAULT_REPETITION_PENALTY,
|
| 26 |
+
DEFAULT_TEMPERATURE,
|
| 27 |
+
DEFAULT_TOP_P,
|
| 28 |
+
QuantizedSkinGPTModel,
|
| 29 |
+
build_single_turn_messages,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 34 |
+
parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 inference")
|
| 35 |
+
parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
|
| 36 |
+
parser.add_argument("--image_path", type=str, required=True, help="Path to the test image")
|
| 37 |
+
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for diagnosis")
|
| 38 |
+
parser.add_argument("--max_new_tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)
|
| 39 |
+
parser.add_argument("--do_sample", action="store_true", default=DEFAULT_DO_SAMPLE)
|
| 40 |
+
parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE)
|
| 41 |
+
parser.add_argument("--top_p", type=float, default=DEFAULT_TOP_P)
|
| 42 |
+
parser.add_argument("--repetition_penalty", type=float, default=DEFAULT_REPETITION_PENALTY)
|
| 43 |
+
return parser
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main() -> None:
|
| 47 |
+
args = build_parser().parse_args()
|
| 48 |
+
|
| 49 |
+
if not Path(args.image_path).exists():
|
| 50 |
+
print(f"Error: Image not found at {args.image_path}")
|
| 51 |
+
return
|
| 52 |
+
|
| 53 |
+
print("=== [1] Initializing INT4 Quantization ===")
|
| 54 |
+
print("BitsAndBytesConfig will be applied during model loading.")
|
| 55 |
+
|
| 56 |
+
print("=== [2] Loading Model and Processor ===")
|
| 57 |
+
start_load = time.time()
|
| 58 |
+
model = QuantizedSkinGPTModel(args.model_path)
|
| 59 |
+
print(f"Model loaded in {time.time() - start_load:.2f} seconds.")
|
| 60 |
+
|
| 61 |
+
print("=== [3] Preparing Input ===")
|
| 62 |
+
messages = build_single_turn_messages(args.image_path, args.prompt)
|
| 63 |
+
|
| 64 |
+
print("=== [4] Generating Response ===")
|
| 65 |
+
start_infer = time.time()
|
| 66 |
+
output_text = model.generate_response(
|
| 67 |
+
messages,
|
| 68 |
+
max_new_tokens=args.max_new_tokens,
|
| 69 |
+
do_sample=args.do_sample,
|
| 70 |
+
temperature=args.temperature,
|
| 71 |
+
top_p=args.top_p,
|
| 72 |
+
repetition_penalty=args.repetition_penalty,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
print(f"Inference completed in {time.time() - start_infer:.2f} seconds.")
|
| 76 |
+
print("\n================ MODEL OUTPUT ================\n")
|
| 77 |
+
print(output_text)
|
| 78 |
+
print("\n==============================================\n")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
inference/int4_quantized/model_utils.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from threading import Thread
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from qwen_vl_utils import process_vision_info
|
| 11 |
+
from transformers import (
|
| 12 |
+
AutoProcessor,
|
| 13 |
+
BitsAndBytesConfig,
|
| 14 |
+
StoppingCriteria,
|
| 15 |
+
StoppingCriteriaList,
|
| 16 |
+
TextIteratorStreamer,
|
| 17 |
+
)
|
| 18 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 19 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
DEFAULT_MODEL_PATH = "./checkpoints/int4"
|
| 23 |
+
DEFAULT_SYSTEM_PROMPT = (
|
| 24 |
+
"You are a professional AI dermatology assistant. "
|
| 25 |
+
"Reason step by step, keep the reasoning concise, avoid repetition, "
|
| 26 |
+
"and always finish with <answer>...</answer>."
|
| 27 |
+
)
|
| 28 |
+
DEFAULT_MAX_NEW_TOKENS = 768
|
| 29 |
+
DEFAULT_CONTINUE_TOKENS = 256
|
| 30 |
+
DEFAULT_DO_SAMPLE = False
|
| 31 |
+
DEFAULT_TEMPERATURE = 0.2
|
| 32 |
+
DEFAULT_TOP_P = 0.9
|
| 33 |
+
DEFAULT_REPETITION_PENALTY = 1.15
|
| 34 |
+
DEFAULT_NO_REPEAT_NGRAM_SIZE = 3
|
| 35 |
+
DEFAULT_PROMPT = (
|
| 36 |
+
"Act as a dermatologist. Analyze the visual features of this skin lesion "
|
| 37 |
+
"step by step, and provide a final diagnosis."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def resolve_model_path(model_path: str = DEFAULT_MODEL_PATH) -> str:
|
| 42 |
+
raw_path = Path(model_path).expanduser()
|
| 43 |
+
repo_root = Path(__file__).resolve().parents[2]
|
| 44 |
+
candidates = [raw_path]
|
| 45 |
+
|
| 46 |
+
if not raw_path.is_absolute():
|
| 47 |
+
candidates.append(Path.cwd() / raw_path)
|
| 48 |
+
candidates.append(repo_root / raw_path)
|
| 49 |
+
if raw_path.parts and raw_path.parts[0] == repo_root.name:
|
| 50 |
+
candidates.append(repo_root.joinpath(*raw_path.parts[1:]))
|
| 51 |
+
|
| 52 |
+
for candidate in candidates:
|
| 53 |
+
if candidate.exists():
|
| 54 |
+
return str(candidate)
|
| 55 |
+
return str(raw_path)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_single_turn_messages(
|
| 59 |
+
image_path: str,
|
| 60 |
+
prompt: str,
|
| 61 |
+
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
| 62 |
+
) -> list[dict]:
|
| 63 |
+
return [
|
| 64 |
+
{
|
| 65 |
+
"role": "user",
|
| 66 |
+
"content": [
|
| 67 |
+
{"type": "image", "image": image_path},
|
| 68 |
+
{"type": "text", "text": f"{system_prompt}\n\n{prompt}"},
|
| 69 |
+
],
|
| 70 |
+
}
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def build_quantization_config() -> BitsAndBytesConfig:
|
| 75 |
+
return BitsAndBytesConfig(
|
| 76 |
+
load_in_4bit=True,
|
| 77 |
+
bnb_4bit_quant_type="nf4",
|
| 78 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 79 |
+
bnb_4bit_use_double_quant=True,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def resolve_quantized_device_map():
|
| 84 |
+
if not torch.cuda.is_available():
|
| 85 |
+
raise RuntimeError("INT4 quantized inference requires a CUDA GPU.")
|
| 86 |
+
return {"": f"cuda:{torch.cuda.current_device()}"}
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class StopOnTokenSequence(StoppingCriteria):
|
| 90 |
+
def __init__(self, stop_ids: list[int]):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.stop_ids = stop_ids
|
| 93 |
+
self.stop_length = len(stop_ids)
|
| 94 |
+
|
| 95 |
+
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
| 96 |
+
if self.stop_length == 0 or input_ids.shape[1] < self.stop_length:
|
| 97 |
+
return False
|
| 98 |
+
return input_ids[0, -self.stop_length :].tolist() == self.stop_ids
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class ExpertBlock(nn.Module):
|
| 102 |
+
def __init__(self, hidden_dim, bottleneck_dim=64):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.net = nn.Sequential(
|
| 105 |
+
nn.Linear(hidden_dim, bottleneck_dim),
|
| 106 |
+
nn.ReLU(),
|
| 107 |
+
nn.Linear(bottleneck_dim, hidden_dim),
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
return self.net(x)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class SkinAwareMoEAdapter(nn.Module):
|
| 115 |
+
def __init__(self, hidden_dim, num_experts=8, top_k=2, bottleneck_dim=64):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.num_experts = num_experts
|
| 118 |
+
self.top_k = top_k
|
| 119 |
+
self.router_img = nn.Linear(hidden_dim, num_experts, bias=False)
|
| 120 |
+
self.router_skin = nn.Linear(3, num_experts, bias=False)
|
| 121 |
+
self.experts = nn.ModuleList(
|
| 122 |
+
[ExpertBlock(hidden_dim, bottleneck_dim) for _ in range(num_experts)]
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def forward(self, x: torch.Tensor, skin_probs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 126 |
+
img_logits = self.router_img(x)
|
| 127 |
+
skin_bias = self.router_skin(skin_probs)
|
| 128 |
+
router_logits = img_logits + skin_bias
|
| 129 |
+
router_probs = F.softmax(router_logits, dim=-1)
|
| 130 |
+
|
| 131 |
+
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| 132 |
+
top_k_probs = top_k_probs / (top_k_probs.sum(dim=-1, keepdim=True) + 1e-6)
|
| 133 |
+
|
| 134 |
+
final_output = torch.zeros_like(x)
|
| 135 |
+
for expert_idx, expert in enumerate(self.experts):
|
| 136 |
+
expert_mask = top_k_indices == expert_idx
|
| 137 |
+
if expert_mask.any():
|
| 138 |
+
rows, k_indices = torch.where(expert_mask)
|
| 139 |
+
inp = x[rows]
|
| 140 |
+
out = expert(inp)
|
| 141 |
+
weights = top_k_probs[rows, k_indices].unsqueeze(-1)
|
| 142 |
+
final_output.index_add_(0, rows, (out * weights).to(final_output.dtype))
|
| 143 |
+
|
| 144 |
+
mean_prob = router_probs.mean(0)
|
| 145 |
+
mask_all = torch.zeros_like(router_probs)
|
| 146 |
+
mask_all.scatter_(1, top_k_indices, 1.0)
|
| 147 |
+
mean_freq = mask_all.mean(0)
|
| 148 |
+
aux_loss = (mean_prob * mean_freq).sum() * self.num_experts
|
| 149 |
+
|
| 150 |
+
return x + final_output, aux_loss
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class PatchDistillHead(nn.Module):
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
embed_dim: int = 1024,
|
| 157 |
+
adapter_layers: int = 4,
|
| 158 |
+
in_dim: Optional[int] = None,
|
| 159 |
+
out_dim: Optional[int] = None,
|
| 160 |
+
num_experts: int = 8,
|
| 161 |
+
top_k: int = 2,
|
| 162 |
+
):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.embed_dim = embed_dim
|
| 165 |
+
self.in_proj = None if in_dim is None else nn.Linear(in_dim, embed_dim, bias=False)
|
| 166 |
+
self.skin_classifier = nn.Sequential(
|
| 167 |
+
nn.Linear(embed_dim, 64),
|
| 168 |
+
nn.ReLU(),
|
| 169 |
+
nn.Linear(64, 3),
|
| 170 |
+
)
|
| 171 |
+
self.adapters = nn.ModuleList(
|
| 172 |
+
[
|
| 173 |
+
SkinAwareMoEAdapter(embed_dim, num_experts=num_experts, top_k=top_k)
|
| 174 |
+
for _ in range(adapter_layers)
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
self.out_proj: nn.Module = (
|
| 178 |
+
nn.Identity() if out_dim is None else nn.Linear(embed_dim, out_dim)
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
def _ensure_in_proj(self, din: int, device, dtype):
|
| 182 |
+
if self.in_proj is None:
|
| 183 |
+
self.in_proj = nn.Linear(din, self.embed_dim, bias=False).to(device=device, dtype=dtype)
|
| 184 |
+
|
| 185 |
+
def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> dict:
|
| 186 |
+
_, din = pixel_values.shape
|
| 187 |
+
counts = (image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]).tolist()
|
| 188 |
+
device, dtype = pixel_values.device, pixel_values.dtype
|
| 189 |
+
self._ensure_in_proj(din, device, dtype)
|
| 190 |
+
chunks = torch.split(pixel_values, counts, dim=0)
|
| 191 |
+
|
| 192 |
+
pooled, all_skin_logits = [], []
|
| 193 |
+
total_aux_loss = torch.tensor(0.0, device=device, dtype=dtype)
|
| 194 |
+
|
| 195 |
+
for x in chunks:
|
| 196 |
+
h = self.in_proj(x)
|
| 197 |
+
global_feat = h.mean(dim=0, keepdim=True)
|
| 198 |
+
skin_logits = self.skin_classifier(global_feat)
|
| 199 |
+
skin_probs = F.softmax(skin_logits, dim=-1)
|
| 200 |
+
all_skin_logits.append(skin_logits)
|
| 201 |
+
skin_probs_expanded = skin_probs.expand(h.size(0), -1)
|
| 202 |
+
|
| 203 |
+
for adapter in self.adapters:
|
| 204 |
+
h, layer_loss = adapter(h, skin_probs_expanded)
|
| 205 |
+
total_aux_loss += layer_loss
|
| 206 |
+
pooled.append(h.mean(dim=0))
|
| 207 |
+
|
| 208 |
+
vision_embed = torch.stack(pooled, dim=0)
|
| 209 |
+
vision_proj = self.out_proj(vision_embed)
|
| 210 |
+
return {
|
| 211 |
+
"vision_embed": vision_embed,
|
| 212 |
+
"vision_proj": vision_proj,
|
| 213 |
+
"aux_loss": total_aux_loss,
|
| 214 |
+
"skin_logits": torch.cat(all_skin_logits, dim=0),
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
def configure_out_dim(self, out_dim: int):
|
| 218 |
+
if isinstance(self.out_proj, nn.Linear) and self.out_proj.out_features == out_dim:
|
| 219 |
+
return
|
| 220 |
+
self.out_proj = (
|
| 221 |
+
nn.Linear(self.embed_dim, out_dim, bias=False)
|
| 222 |
+
if out_dim != self.embed_dim
|
| 223 |
+
else nn.Identity()
|
| 224 |
+
)
|
| 225 |
+
try:
|
| 226 |
+
params = next(self.parameters())
|
| 227 |
+
self.out_proj.to(device=params.device, dtype=params.dtype)
|
| 228 |
+
except StopIteration:
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class SkinVLModelWithAdapter(Qwen2_5_VLForConditionalGeneration):
|
| 233 |
+
def __init__(self, config):
|
| 234 |
+
super().__init__(config)
|
| 235 |
+
self.distill_head = PatchDistillHead(
|
| 236 |
+
embed_dim=1024,
|
| 237 |
+
adapter_layers=4,
|
| 238 |
+
num_experts=8,
|
| 239 |
+
top_k=2,
|
| 240 |
+
in_dim=1176,
|
| 241 |
+
)
|
| 242 |
+
bottleneck = 64
|
| 243 |
+
self.text_bias = nn.Sequential(
|
| 244 |
+
nn.Linear(1024, bottleneck, bias=False),
|
| 245 |
+
nn.Tanh(),
|
| 246 |
+
nn.Linear(bottleneck, config.hidden_size, bias=False),
|
| 247 |
+
)
|
| 248 |
+
self.logit_bias_scale = nn.Parameter(torch.tensor(2.5, dtype=torch.bfloat16))
|
| 249 |
+
|
| 250 |
+
def forward(self, *args, **kwargs):
|
| 251 |
+
skin_vocab_mask = kwargs.pop("skin_vocab_mask", None)
|
| 252 |
+
skin_labels = kwargs.get("skin_labels", None)
|
| 253 |
+
pixel_values = kwargs.get("pixel_values", None)
|
| 254 |
+
image_grid_thw = kwargs.get("image_grid_thw", None)
|
| 255 |
+
|
| 256 |
+
if isinstance(pixel_values, list):
|
| 257 |
+
try:
|
| 258 |
+
pixel_values = torch.stack(pixel_values)
|
| 259 |
+
kwargs["pixel_values"] = pixel_values
|
| 260 |
+
except Exception:
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
outputs = super().forward(*args, **kwargs)
|
| 264 |
+
|
| 265 |
+
vision_embed = None
|
| 266 |
+
loss_skin = torch.tensor(0.0, device=outputs.logits.device)
|
| 267 |
+
aux_loss = torch.tensor(0.0, device=outputs.logits.device)
|
| 268 |
+
|
| 269 |
+
if pixel_values is not None and image_grid_thw is not None:
|
| 270 |
+
if not isinstance(pixel_values, torch.Tensor):
|
| 271 |
+
if isinstance(pixel_values, list):
|
| 272 |
+
pixel_values = torch.stack(pixel_values)
|
| 273 |
+
else:
|
| 274 |
+
pixel_values = torch.tensor(pixel_values)
|
| 275 |
+
|
| 276 |
+
image_grid_thw = image_grid_thw.to(pixel_values.device)
|
| 277 |
+
side = self.distill_head(pixel_values=pixel_values, image_grid_thw=image_grid_thw)
|
| 278 |
+
vision_embed = side["vision_embed"]
|
| 279 |
+
aux_loss = side["aux_loss"]
|
| 280 |
+
|
| 281 |
+
if skin_labels is not None:
|
| 282 |
+
skin_labels = skin_labels.to(side["skin_logits"].device)
|
| 283 |
+
loss_skin = nn.CrossEntropyLoss()(side["skin_logits"], skin_labels)
|
| 284 |
+
|
| 285 |
+
setattr(outputs, "vision_embed", vision_embed)
|
| 286 |
+
setattr(outputs, "vision_proj", side["vision_proj"])
|
| 287 |
+
setattr(outputs, "loss_skin", loss_skin)
|
| 288 |
+
setattr(outputs, "aux_loss", aux_loss)
|
| 289 |
+
setattr(outputs, "skin_logits", side["skin_logits"])
|
| 290 |
+
|
| 291 |
+
pack_vision_proj = (
|
| 292 |
+
side["vision_proj"]
|
| 293 |
+
if side["vision_proj"] is not None
|
| 294 |
+
else torch.tensor(0.0, device=aux_loss.device)
|
| 295 |
+
)
|
| 296 |
+
pack_skin_logits = (
|
| 297 |
+
side["skin_logits"]
|
| 298 |
+
if side["skin_logits"] is not None
|
| 299 |
+
else torch.tensor(0.0, device=aux_loss.device)
|
| 300 |
+
)
|
| 301 |
+
outputs.attentions = (pack_vision_proj, aux_loss, pack_skin_logits)
|
| 302 |
+
|
| 303 |
+
self.latest_side_output = {
|
| 304 |
+
"vision_proj": side["vision_proj"],
|
| 305 |
+
"aux_loss": aux_loss,
|
| 306 |
+
"skin_logits": side["skin_logits"],
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
if hasattr(outputs, "logits") and vision_embed is not None and skin_vocab_mask is not None:
|
| 310 |
+
bias_features = self.text_bias(vision_embed.to(self.logit_bias_scale.dtype))
|
| 311 |
+
lm_weight = self.lm_head.weight.to(bias_features.dtype)
|
| 312 |
+
vocab_bias = F.linear(bias_features, lm_weight)
|
| 313 |
+
scale = self.logit_bias_scale.to(outputs.logits.dtype)
|
| 314 |
+
outputs.logits = outputs.logits + (scale * vocab_bias[:, None, :] * skin_vocab_mask)
|
| 315 |
+
|
| 316 |
+
if outputs.loss is not None:
|
| 317 |
+
outputs.loss = outputs.loss + loss_skin + (0.01 * aux_loss)
|
| 318 |
+
|
| 319 |
+
return outputs
|
| 320 |
+
|
| 321 |
+
def freeze_all_but_distill(self):
|
| 322 |
+
self.requires_grad_(False)
|
| 323 |
+
for params in self.distill_head.parameters():
|
| 324 |
+
params.requires_grad_(True)
|
| 325 |
+
for params in self.text_bias.parameters():
|
| 326 |
+
params.requires_grad_(True)
|
| 327 |
+
self.logit_bias_scale.requires_grad_(True)
|
| 328 |
+
|
| 329 |
+
def configure_out_dim(self, out_dim: int):
|
| 330 |
+
self.distill_head.configure_out_dim(out_dim)
|
| 331 |
+
|
| 332 |
+
def project_only(self, vision_embed: torch.Tensor) -> torch.Tensor:
|
| 333 |
+
return self.distill_head.out_proj(vision_embed)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def load_quantized_model_and_processor(model_path: str = DEFAULT_MODEL_PATH):
|
| 337 |
+
resolved_model_path = resolve_model_path(model_path)
|
| 338 |
+
quantization_config = build_quantization_config()
|
| 339 |
+
model = SkinVLModelWithAdapter.from_pretrained(
|
| 340 |
+
resolved_model_path,
|
| 341 |
+
device_map=resolve_quantized_device_map(),
|
| 342 |
+
quantization_config=quantization_config,
|
| 343 |
+
attn_implementation="sdpa",
|
| 344 |
+
)
|
| 345 |
+
model.eval()
|
| 346 |
+
processor = AutoProcessor.from_pretrained(
|
| 347 |
+
resolved_model_path,
|
| 348 |
+
min_pixels=256 * 28 * 28,
|
| 349 |
+
max_pixels=1280 * 28 * 28,
|
| 350 |
+
)
|
| 351 |
+
return model, processor
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def get_model_device(model) -> torch.device:
|
| 355 |
+
try:
|
| 356 |
+
return model.device
|
| 357 |
+
except AttributeError:
|
| 358 |
+
return next(model.parameters()).device
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def prepare_inputs(processor, model, messages: list[dict]):
|
| 362 |
+
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 363 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
| 364 |
+
inputs = processor(
|
| 365 |
+
text=[text],
|
| 366 |
+
images=image_inputs,
|
| 367 |
+
videos=video_inputs,
|
| 368 |
+
padding=True,
|
| 369 |
+
return_tensors="pt",
|
| 370 |
+
).to(get_model_device(model))
|
| 371 |
+
inputs.pop("mm_token_type_ids", None)
|
| 372 |
+
return inputs
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class QuantizedSkinGPTModel:
|
| 376 |
+
def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
|
| 377 |
+
resolved_model_path = resolve_model_path(model_path)
|
| 378 |
+
print(f"Loading INT4 model from {resolved_model_path}...")
|
| 379 |
+
self.model, self.processor = load_quantized_model_and_processor(resolved_model_path)
|
| 380 |
+
self.model_path = resolved_model_path
|
| 381 |
+
self.device = get_model_device(self.model)
|
| 382 |
+
self.stop_ids = self.processor.tokenizer.encode("</answer>", add_special_tokens=False)
|
| 383 |
+
print(f"Model loaded successfully on {self.device}.")
|
| 384 |
+
|
| 385 |
+
@staticmethod
|
| 386 |
+
def has_complete_answer(text: str) -> bool:
|
| 387 |
+
return "<answer>" in text and "</answer>" in text
|
| 388 |
+
|
| 389 |
+
def _build_generation_kwargs(
|
| 390 |
+
self,
|
| 391 |
+
inputs,
|
| 392 |
+
max_new_tokens: int,
|
| 393 |
+
do_sample: bool,
|
| 394 |
+
temperature: float,
|
| 395 |
+
repetition_penalty: float,
|
| 396 |
+
top_p: float,
|
| 397 |
+
no_repeat_ngram_size: int,
|
| 398 |
+
streamer=None,
|
| 399 |
+
) -> dict:
|
| 400 |
+
generation_kwargs = {
|
| 401 |
+
**inputs,
|
| 402 |
+
"max_new_tokens": max_new_tokens,
|
| 403 |
+
"do_sample": do_sample,
|
| 404 |
+
"repetition_penalty": repetition_penalty,
|
| 405 |
+
"no_repeat_ngram_size": no_repeat_ngram_size,
|
| 406 |
+
"use_cache": True,
|
| 407 |
+
"stopping_criteria": StoppingCriteriaList([StopOnTokenSequence(self.stop_ids)]),
|
| 408 |
+
}
|
| 409 |
+
if streamer is not None:
|
| 410 |
+
generation_kwargs["streamer"] = streamer
|
| 411 |
+
if do_sample:
|
| 412 |
+
generation_kwargs["temperature"] = temperature
|
| 413 |
+
generation_kwargs["top_p"] = top_p
|
| 414 |
+
return generation_kwargs
|
| 415 |
+
|
| 416 |
+
def _generate_text(
|
| 417 |
+
self,
|
| 418 |
+
messages,
|
| 419 |
+
max_new_tokens: int,
|
| 420 |
+
do_sample: bool,
|
| 421 |
+
temperature: float,
|
| 422 |
+
repetition_penalty: float,
|
| 423 |
+
top_p: float,
|
| 424 |
+
no_repeat_ngram_size: int,
|
| 425 |
+
) -> str:
|
| 426 |
+
inputs = prepare_inputs(self.processor, self.model, messages)
|
| 427 |
+
generation_kwargs = self._build_generation_kwargs(
|
| 428 |
+
inputs=inputs,
|
| 429 |
+
max_new_tokens=max_new_tokens,
|
| 430 |
+
do_sample=do_sample,
|
| 431 |
+
temperature=temperature,
|
| 432 |
+
repetition_penalty=repetition_penalty,
|
| 433 |
+
top_p=top_p,
|
| 434 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
with torch.inference_mode():
|
| 438 |
+
generated_ids = self.model.generate(**generation_kwargs)
|
| 439 |
+
|
| 440 |
+
generated_ids_trimmed = [
|
| 441 |
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 442 |
+
]
|
| 443 |
+
output_text = self.processor.batch_decode(
|
| 444 |
+
generated_ids_trimmed,
|
| 445 |
+
skip_special_tokens=True,
|
| 446 |
+
clean_up_tokenization_spaces=False,
|
| 447 |
+
)
|
| 448 |
+
return output_text[0]
|
| 449 |
+
|
| 450 |
+
def generate_response(
|
| 451 |
+
self,
|
| 452 |
+
messages,
|
| 453 |
+
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
| 454 |
+
continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
|
| 455 |
+
do_sample: bool = DEFAULT_DO_SAMPLE,
|
| 456 |
+
temperature: float = DEFAULT_TEMPERATURE,
|
| 457 |
+
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
|
| 458 |
+
top_p: float = DEFAULT_TOP_P,
|
| 459 |
+
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
| 460 |
+
) -> str:
|
| 461 |
+
output_text = self._generate_text(
|
| 462 |
+
messages=messages,
|
| 463 |
+
max_new_tokens=max_new_tokens,
|
| 464 |
+
do_sample=do_sample,
|
| 465 |
+
temperature=temperature,
|
| 466 |
+
repetition_penalty=repetition_penalty,
|
| 467 |
+
top_p=top_p,
|
| 468 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 469 |
+
)
|
| 470 |
+
if not self.has_complete_answer(output_text) and continue_tokens > 0:
|
| 471 |
+
output_text = self._generate_text(
|
| 472 |
+
messages=messages,
|
| 473 |
+
max_new_tokens=max_new_tokens + continue_tokens,
|
| 474 |
+
do_sample=do_sample,
|
| 475 |
+
temperature=temperature,
|
| 476 |
+
repetition_penalty=repetition_penalty,
|
| 477 |
+
top_p=top_p,
|
| 478 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 479 |
+
)
|
| 480 |
+
return output_text
|
| 481 |
+
|
| 482 |
+
def generate_response_stream(
|
| 483 |
+
self,
|
| 484 |
+
messages,
|
| 485 |
+
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
| 486 |
+
continue_tokens: int = DEFAULT_CONTINUE_TOKENS,
|
| 487 |
+
do_sample: bool = DEFAULT_DO_SAMPLE,
|
| 488 |
+
temperature: float = DEFAULT_TEMPERATURE,
|
| 489 |
+
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
|
| 490 |
+
top_p: float = DEFAULT_TOP_P,
|
| 491 |
+
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
| 492 |
+
):
|
| 493 |
+
inputs = prepare_inputs(self.processor, self.model, messages)
|
| 494 |
+
streamer = TextIteratorStreamer(
|
| 495 |
+
self.processor.tokenizer,
|
| 496 |
+
skip_prompt=True,
|
| 497 |
+
skip_special_tokens=True,
|
| 498 |
+
)
|
| 499 |
+
generation_kwargs = self._build_generation_kwargs(
|
| 500 |
+
inputs=inputs,
|
| 501 |
+
max_new_tokens=max_new_tokens,
|
| 502 |
+
do_sample=do_sample,
|
| 503 |
+
temperature=temperature,
|
| 504 |
+
repetition_penalty=repetition_penalty,
|
| 505 |
+
top_p=top_p,
|
| 506 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 507 |
+
streamer=streamer,
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
def _generate():
|
| 511 |
+
with torch.inference_mode():
|
| 512 |
+
self.model.generate(**generation_kwargs)
|
| 513 |
+
|
| 514 |
+
thread = Thread(target=_generate)
|
| 515 |
+
thread.start()
|
| 516 |
+
|
| 517 |
+
partial_chunks = []
|
| 518 |
+
for text_chunk in streamer:
|
| 519 |
+
partial_chunks.append(text_chunk)
|
| 520 |
+
yield text_chunk
|
| 521 |
+
|
| 522 |
+
thread.join()
|
| 523 |
+
|
| 524 |
+
partial_text = "".join(partial_chunks)
|
| 525 |
+
if not self.has_complete_answer(partial_text) and continue_tokens > 0:
|
| 526 |
+
completed_text = self._generate_text(
|
| 527 |
+
messages=messages,
|
| 528 |
+
max_new_tokens=max_new_tokens + continue_tokens,
|
| 529 |
+
do_sample=do_sample,
|
| 530 |
+
temperature=temperature,
|
| 531 |
+
repetition_penalty=repetition_penalty,
|
| 532 |
+
top_p=top_p,
|
| 533 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 534 |
+
)
|
| 535 |
+
if completed_text.startswith(partial_text):
|
| 536 |
+
tail_text = completed_text[len(partial_text) :]
|
| 537 |
+
if tail_text:
|
| 538 |
+
yield tail_text
|
inference/int4_quantized/run_api.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/app.py"
|
inference/int4_quantized/run_chat.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/chat.py" "$@"
|
inference/int4_quantized/run_infer.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
PYTHON_EXE="${PYTHON_EXE:-python}"
|
| 5 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 6 |
+
"${PYTHON_EXE}" "${SCRIPT_DIR}/infer.py" "$@"
|
inference/int4_quantized/test_single.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
| 5 |
+
|
| 6 |
+
"${SCRIPT_DIR}/run_infer.sh" "$@"
|
inference/temp_uploads/.ipynb_checkpoints/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef-checkpoint.jpg
DELETED
|
Binary file (47.9 kB)
|
|
|
inference/temp_uploads/.ipynb_checkpoints/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c-checkpoint.jpg
DELETED
|
Binary file (80.7 kB)
|
|
|
inference/temp_uploads/temp_d2b1c6f9a43940d2812f10a8cc8bc3ef.jpg
DELETED
|
Binary file (47.9 kB)
|
|
|
inference/temp_uploads/user_1769671453128_43ccc61bfcb64c6bbbabbadfa887591c.jpg
DELETED
|
Binary file (80.7 kB)
|
|
|
requirements.txt
CHANGED
|
@@ -14,6 +14,10 @@ fastapi>=0.100.0
|
|
| 14 |
uvicorn>=0.20.0
|
| 15 |
python-multipart>=0.0.6
|
| 16 |
openai>=1.0.0 # For DeepSeek API (OpenAI-compatible)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Install latest transformers from source (Required for Qwen2.5-VL/Vision-R1)
|
| 19 |
git+https://github.com/huggingface/transformers.git
|
|
@@ -23,4 +27,4 @@ git+https://github.com/huggingface/transformers.git
|
|
| 23 |
|
| 24 |
# For potential future demo usage
|
| 25 |
gradio==5.4.0
|
| 26 |
-
gradio_client==1.4.2
|
|
|
|
| 14 |
uvicorn>=0.20.0
|
| 15 |
python-multipart>=0.0.6
|
| 16 |
openai>=1.0.0 # For DeepSeek API (OpenAI-compatible)
|
| 17 |
+
bitsandbytes>=0.43.0 # Required for INT4 quantized inference
|
| 18 |
+
# Attention notes:
|
| 19 |
+
# - SDPA is built into PyTorch 2.x
|
| 20 |
+
# - flash-attn is optional and mainly useful on GPUs officially supported by the project
|
| 21 |
|
| 22 |
# Install latest transformers from source (Required for Qwen2.5-VL/Vision-R1)
|
| 23 |
git+https://github.com/huggingface/transformers.git
|
|
|
|
| 27 |
|
| 28 |
# For potential future demo usage
|
| 29 |
gradio==5.4.0
|
| 30 |
+
gradio_client==1.4.2
|