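# NOTE: the next three lines are residue from the web page this file was
# copied from; they are kept as comments so the module remains valid Python.
# Klaus-Chow's picture
# Update app.py
# 09bc434 verified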
import os
import gradio as gr
import torch
import time
from swift.llm import (
get_model_tokenizer, get_template, inference,
get_default_template_type
)
from swift.utils import seed_everything
# ========== Model configuration ==========
# Each setting is read from the environment variable of the same name and
# falls back to the literal default when that variable is not set.
MODEL_TYPE = os.getenv("MODEL_TYPE", "internvl2-2b")
print(f"Model type: {MODEL_TYPE}")

MODEL_PATH = os.getenv("MODEL_PATH", "./model")
print(f"Model path: {MODEL_PATH}")

# Ask swift for the default template type of the configured model type.
template_type = get_default_template_type(MODEL_TYPE)
print(f"Template type: {template_type}")
# ========== Model loading ==========
# Runs once at import time; `model` and `template` are reused by every
# request handled by `predict`.
model, tokenizer = get_model_tokenizer(
    model_id_or_path=MODEL_PATH,
    model_type=MODEL_TYPE,
    torch_dtype=torch.bfloat16,
)

# Limit each generation to 256 new tokens.
model.generation_config.max_new_tokens = 256

template = get_template(template_type, tokenizer)

# Fixed seed so repeated runs start from the same random state.
seed_everything(42)

# Startup banner: rule, message, rule — one item per line.
print("=" * 50, "Model loaded successfully!", "=" * 50, sep="\n")
# ========== Inference ==========
# Question used when the caller leaves the (optional) question empty.
_DEFAULT_QUERY = "这张图片中有任何人为编辑的迹象吗?"
# Only absolute http(s) URLs are accepted as image locations.
_URL_SCHEMES = ("http://", "https://")


def predict(image_url: str, query: str = _DEFAULT_QUERY) -> str:
    """Run the image-editing detection model on one image URL.

    Args:
        image_url: Location of the image to inspect. Leading and trailing
            whitespace is ignored; the value must start with ``http://`` or
            ``https://``.
        query: Question to ask about the image. A missing or blank value
            falls back to the default question.

    Returns:
        The model response followed by a footer with the inference time in
        seconds, or a fixed error message when ``image_url`` is not an
        http(s) URL (the model is not called in that case).

    NOTE(review): the call below follows the ms-swift 2.x ``inference``
    signature (a query string in, a ``(response, history)`` tuple out) —
    confirm against the swift version pinned for this Space.
    """
    # The URL textbox is multi-line, so pasted values may carry stray
    # whitespace or newlines; None is treated like an empty string.
    image_url = (image_url or "").strip()
    if not image_url.startswith(_URL_SCHEMES):
        return "错误:请提供有效的图片URL(以http/https开头)"

    # The question is optional in the UI; never send an empty prompt.
    if not query or not query.strip():
        query = _DEFAULT_QUERY

    # The image is referenced inline through an <img>...</img> tag.
    prompt = f"<img>{image_url}</img>{query}"

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    started = time.perf_counter()
    response, _history = inference(model, template, prompt)
    elapsed = time.perf_counter() - started

    # Append the timing footer to the model output.
    return f"{response}\n\n【推理用时: {elapsed:.2f}秒】"
# ========== Gradio UI ==========
with gr.Blocks(title="图像编辑检测 - Tiger Vision") as demo:
    # Page header.
    gr.Markdown("""
# 🐯 Tiger Vision
**基于 InternVL2 的图像篡改检测服务**
""")

    with gr.Row():
        # First column: the two request inputs and the trigger button.
        with gr.Column():
            image_input = gr.Textbox(
                lines=2,
                label="图片URL",
                placeholder="https://example.com/image.jpg",
            )
            question_input = gr.Textbox(
                lines=2,
                label="检测问题(可选)",
                value="这张图片中有任何人为编辑的迹象吗?",
            )
            submit_btn = gr.Button("🔍 开始检测", variant="primary")

        # Second column: the detection result.
        with gr.Column():
            output = gr.Textbox(
                lines=12,
                label="检测结果",
                show_copy_button=True,
            )

    # One example row that fills both inputs.
    gr.Examples(
        inputs=[image_input, question_input],
        examples=[
            [
                "https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/road.png",
                "这张图片中有任何人为编辑的迹象吗?",
            ],
        ],
    )

    # The button's click event and the URL textbox's submit event both run
    # `predict` with the same inputs and output.
    for _register in (submit_btn.click, image_input.submit):
        _register(fn=predict, inputs=[image_input, question_input], outputs=output)
# ========== Entry point (standard Space configuration) ==========
if __name__ == "__main__":
    # Bind to every network interface on port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")