|
|
| import json
|
| from pathlib import Path
|
| from flask import request, send_file, current_app, make_response
|
| from flask_restful import Resource
|
| from flask_jwt_extended import jwt_required, get_jwt_identity
|
| from datetime import datetime
|
| from io import BytesIO
|
| import zipfile
|
| import os
|
|
|
| from app import db, Setting
|
| from app.models import Customer
|
| from app.models.translate import Translate
|
| from app.resources.task.translate_service import TranslateEngine
|
| from app.utils.response import APIResponse
|
| from app.utils.check_utils import AIChecker
|
|
|
|
|
# Static translation defaults (model list, default model, thread cap, prompt).
# NOTE(review): nothing visible in this module reads this constant —
# TranslateSettingResource hard-codes its own fallbacks (e.g. max_threads=10).
# Confirm whether it is still used elsewhere.
TRANSLATE_SETTINGS = dict(
    models=["gpt-3.5-turbo", "gpt-4"],
    default_model="gpt-3.5-turbo",
    max_threads=5,
    prompt_template="请将以下内容翻译为{target_lang}",
)
|
|
|
|
|
class TranslateStartResource1(Resource):
    """Configure an existing translation record from form data and start it.

    NOTE(review): near-verbatim duplicate of ``TranslateStartResource`` defined
    later in this module — confirm whether this variant is still routed.
    """

    # Form fields that must be present in every request.
    REQUIRED_FIELDS = (
        'server', 'model', 'lang', 'uuid',
        'prompt', 'threads', 'file_name',
    )

    @staticmethod
    def _safe_filename(filename):
        """Return the last path component of a client-supplied file name.

        Returns ``None`` when nothing usable remains ('' , '.' or '..'), so the
        result can never point outside the storage directory.
        """
        name = Path(filename).name
        return None if name in ('', '.', '..') else name

    @staticmethod
    def _get_absolute_storage_path(filename):
        """Return ``<parent of app root>/storage/translate/<YYYY-MM-DD>/<filename>``.

        The dated directory is created if it does not exist yet.
        """
        base_dir = Path(current_app.root_path).parent.absolute()
        date_str = datetime.now().strftime('%Y-%m-%d')
        target_dir = base_dir / "storage" / "translate" / date_str
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir / filename

    @staticmethod
    def _optional_int(value):
        """Convert a form value to ``int``; an empty string or ``None`` gives ``None``."""
        return int(value) if value else None

    @jwt_required()
    def post(self):
        """Start a translation task (absolute target path, multiple parameters).

        Form fields: everything in ``REQUIRED_FIELDS``; ``api_url`` and
        ``api_key`` are additionally required when ``server == 'openai'``.
        Optional: ``backup_model``, ``origin_lang``, ``comparison_id``,
        ``prompt_id``, ``doc2x_flag``, ``doc2x_secret_key``, ``type[2]``.

        Returns 400 on missing/malformed parameters, 404 when no record with the
        given ``uuid`` belongs to the caller, 500 on unexpected failures.
        """
        data = request.form

        if not all(field in data for field in self.REQUIRED_FIELDS):
            return APIResponse.error("缺少必要参数", 400)

        if data['server'] == 'openai' and not all(k in data for k in ('api_url', 'api_key')):
            return APIResponse.error("OpenAI服务需要API地址和密钥", 400)

        origin_filename = data['file_name']
        # file_name is client-controlled: keep only the final component so that
        # names like "../../x" or "/etc/x" cannot escape the storage directory.
        safe_name = self._safe_filename(origin_filename)
        if safe_name is None:
            return APIResponse.error("缺少必要参数", 400)

        # Validate numeric fields up front so malformed input is a 400, not a 500.
        try:
            threads = int(data['threads'])
            # The defaults are the string '0', which is truthy and therefore yields 0.
            comparison_id = self._optional_int(data.get('comparison_id', '0'))
            prompt_id = self._optional_int(data.get('prompt_id', '0'))
        except ValueError:
            return APIResponse.error("参数格式错误", 400)

        try:
            # Scope the lookup to the caller, as the other endpoints in this
            # module do, so one user cannot reconfigure another user's task.
            translate = Translate.query.filter_by(
                uuid=data['uuid'],
                customer_id=get_jwt_identity(),
            ).first()
            if not translate:
                return APIResponse.error("未找到对应的翻译记录", 404)

            target_abs_path = str(self._get_absolute_storage_path(safe_name))

            translate.origin_filename = origin_filename
            translate.target_filepath = target_abs_path
            translate.lang = data['lang']
            translate.model = data['model']
            # backup_model is optional; indexing data['backup_model'] raised
            # KeyError (-> 500) whenever the client omitted it.
            translate.backup_model = data.get('backup_model', '')
            translate.type = data.get('type[2]', 'trans_all_only_inherit')
            translate.prompt = data['prompt']
            translate.threads = threads
            translate.api_url = data.get('api_url', '')
            translate.api_key = data.get('api_key', '')
            translate.origin_lang = data.get('origin_lang', '')
            translate.comparison_id = comparison_id
            translate.prompt_id = prompt_id
            translate.doc2x_flag = data.get('doc2x_flag', 'N')
            translate.doc2x_secret_key = data.get('doc2x_secret_key', '')

            db.session.commit()

            TranslateEngine(translate.id).execute()

            return APIResponse.success({
                "task_id": translate.id,
                "uuid": translate.uuid,
                "target_path": target_abs_path,
            })
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True)
            return APIResponse.error("任务启动失败", 500)
|
|
|
|
|
|
|
class TranslateStartResource(Resource):
    """Configure an existing translation record from form data and start it."""

    # Form fields that must be present in every request.
    REQUIRED_FIELDS = (
        'server', 'model', 'lang', 'uuid',
        'prompt', 'threads', 'file_name',
    )

    @staticmethod
    def _safe_filename(filename):
        """Return the last path component of a client-supplied file name.

        Returns ``None`` when nothing usable remains ('' , '.' or '..'), so the
        result can never point outside the storage directory.
        """
        name = Path(filename).name
        return None if name in ('', '.', '..') else name

    @staticmethod
    def _get_absolute_storage_path(filename):
        """Return ``<parent of app root>/storage/translate/<YYYY-MM-DD>/<filename>``.

        The dated directory is created if it does not exist yet.
        """
        base_dir = Path(current_app.root_path).parent.absolute()
        date_str = datetime.now().strftime('%Y-%m-%d')
        target_dir = base_dir / "storage" / "translate" / date_str
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir / filename

    @staticmethod
    def _optional_int(value):
        """Convert a form value to ``int``; an empty string or ``None`` gives ``None``."""
        return int(value) if value else None

    @jwt_required()
    def post(self):
        """Start a translation task (absolute target path, multiple parameters).

        Form fields: everything in ``REQUIRED_FIELDS``; ``api_url`` and
        ``api_key`` are additionally required when ``server == 'openai'``.
        Optional: ``backup_model``, ``origin_lang``, ``comparison_id``,
        ``prompt_id``, ``doc2x_flag``, ``doc2x_secret_key``, ``type[2]``.

        Returns 400 on missing/malformed parameters, 404 when no record with the
        given ``uuid`` belongs to the caller, 500 on unexpected failures.
        """
        data = request.form

        if not all(field in data for field in self.REQUIRED_FIELDS):
            return APIResponse.error("缺少必要参数", 400)

        if data['server'] == 'openai' and not all(k in data for k in ('api_url', 'api_key')):
            return APIResponse.error("OpenAI服务需要API地址和密钥", 400)

        origin_filename = data['file_name']
        # file_name is client-controlled: keep only the final component so that
        # names like "../../x" or "/etc/x" cannot escape the storage directory.
        safe_name = self._safe_filename(origin_filename)
        if safe_name is None:
            return APIResponse.error("缺少必要参数", 400)

        # Validate numeric fields up front so malformed input is a 400, not a 500.
        try:
            threads = int(data['threads'])
            # The defaults are the string '0', which is truthy and therefore yields 0.
            comparison_id = self._optional_int(data.get('comparison_id', '0'))
            prompt_id = self._optional_int(data.get('prompt_id', '0'))
        except ValueError:
            return APIResponse.error("参数格式错误", 400)

        try:
            # Scope the lookup to the caller, as the other endpoints in this
            # module do, so one user cannot reconfigure another user's task.
            translate = Translate.query.filter_by(
                uuid=data['uuid'],
                customer_id=get_jwt_identity(),
            ).first()
            if not translate:
                return APIResponse.error("未找到对应的翻译记录", 404)

            target_abs_path = str(self._get_absolute_storage_path(safe_name))

            translate.origin_filename = origin_filename
            translate.target_filepath = target_abs_path
            translate.lang = data['lang']
            translate.model = data['model']
            # backup_model is optional; indexing data['backup_model'] raised
            # KeyError (-> 500) whenever the client omitted it.
            translate.backup_model = data.get('backup_model', '')
            translate.type = data.get('type[2]', 'trans_all_only_inherit')
            translate.prompt = data['prompt']
            translate.threads = threads
            translate.api_url = data.get('api_url', '')
            translate.api_key = data.get('api_key', '')
            translate.origin_lang = data.get('origin_lang', '')
            translate.comparison_id = comparison_id
            translate.prompt_id = prompt_id
            translate.doc2x_flag = data.get('doc2x_flag', 'N')
            translate.doc2x_secret_key = data.get('doc2x_secret_key', '')

            db.session.commit()

            TranslateEngine(translate.id).execute()

            return APIResponse.success({
                "task_id": translate.id,
                "uuid": translate.uuid,
                "target_path": target_abs_path,
            })
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True)
            return APIResponse.error("任务启动失败", 500)
|
|
|
|
|
|
|
class TranslateListResource(Resource):
    """Paginated listing of the current user's translation records."""

    # Display name per status. The keys are also the accepted values of the
    # optional ``status`` query filter, so the two cannot drift apart.
    STATUS_NAMES = {
        'none': '未开始',
        'process': '进行中',
        'done': '已完成',
        'failed': '失败',
    }

    # Lower-cased file extension -> display category (used by get_file_type).
    FILE_TYPES = {
        'docx': "Word", 'doc': "Word",
        'xlsx': "Excel", 'xls': "Excel",
        'pptx': "PPT", 'ppt': "PPT",
        'pdf': "PDF",
        'txt': "文本", 'md': "文本",
    }

    TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    @jwt_required()
    def get(self):
        """Return one page of the caller's non-deleted translation records.

        Query args: ``page`` (default 1), ``limit`` (default 100) and an
        optional ``status`` (one of STATUS_NAMES). Returns 400 for non-integer
        paging values or an unknown status.
        """
        status_filter = request.args.get('status')

        try:
            page = int(request.args.get('page', '1'))
            limit = int(request.args.get('limit', '100'))
        except ValueError:
            # Pass the status code to APIResponse.error like the rest of this
            # module; ``error(...), 400`` returned a tuple wrapping the response.
            return APIResponse.error("Invalid page or limit value", 400)

        query = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            deleted_flag='N',
        )

        if status_filter:
            if status_filter not in self.STATUS_NAMES:
                return APIResponse.error(f"Invalid status value: {status_filter}", 400)
            query = query.filter_by(status=status_filter)

        # error_out=False: Flask-SQLAlchemy does not abort with 404 for
        # out-of-range page values.
        pagination = query.paginate(page=page, per_page=limit, error_out=False)

        return APIResponse.success({
            'data': [self._serialize(t) for t in pagination.items],
            'total': pagination.total,
            'current_page': pagination.page,
        })

    @classmethod
    def _serialize(cls, t):
        """Build the JSON-ready dict for one Translate row."""
        return {
            'id': t.id,
            'file_type': cls.get_file_type(t.origin_filename),
            'origin_filename': t.origin_filename,
            'status': t.status,
            'status_name': cls.STATUS_NAMES.get(t.status, '未知状态'),
            # process may be NULL (e.g. never started); report 0 instead of crashing.
            'process': float(t.process) if t.process is not None else 0.0,
            'spend_time': cls._format_spend_time(t.created_at, t.end_at),
            'end_at': t.end_at.strftime(cls.TIME_FORMAT) if t.end_at else "--",
            'start_at': t.start_at.strftime(cls.TIME_FORMAT) if t.start_at else "--",
            'lang': t.lang,
            'target_filepath': t.target_filepath,
        }

    @staticmethod
    def _format_spend_time(created_at, end_at):
        """Format ``end_at - created_at`` as "<m>分<s>秒"; "--" if either is missing.

        NOTE(review): measured from created_at, not start_at, so any queueing
        time is included — confirm that is intended.
        """
        if not (created_at and end_at):
            return "--"
        total_seconds = (end_at - created_at).total_seconds()
        return f"{int(total_seconds // 60)}分{int(total_seconds % 60)}秒"

    @staticmethod
    def get_file_type(filename):
        """Map a file name to a display category by its (last) extension.

        Returns "未知" for an empty/None name and "其他" for unknown extensions.
        """
        if not filename:
            return "未知"
        ext = filename.split('.')[-1].lower()
        return TranslateListResource.FILE_TYPES.get(ext, "其他")
|
|
|
|
|
class TranslateSettingResource(Resource):
    """Expose translation-related settings stored in the Setting table."""

    # Setting aliases whose values are copied through unchanged -> config key.
    _PASSTHROUGH_KEYS = {
        'default_model': 'default_model',
        'default_backup': 'default_backup',
        'api_url': 'api_url',
        'api_key': 'api_key',
        'prompt': 'prompt_template',
    }

    @jwt_required()
    def get(self):
        """Return the translation configuration, or a 500 error if loading fails."""
        try:
            settings = self._load_settings_from_db()
            return APIResponse.success(settings)
        except Exception as e:
            current_app.logger.error(f"获取配置失败: {str(e)}", exc_info=True)
            return APIResponse.error(f"获取配置失败: {str(e)}", 500)

    @classmethod
    def _load_settings_from_db(cls):
        """Load translation config from non-deleted 'api_setting'/'other_setting' rows.

        Serialized rows are JSON-decoded. ``models`` accepts a comma-separated
        string or a list; ``threads`` falls back to 10 when it is not a
        non-negative integer. Missing keys are filled with built-in defaults.
        """
        settings = Setting.query.filter(
            Setting.group.in_(['api_setting', 'other_setting']),
            Setting.deleted_flag == 'N'
        ).all()

        config = {}
        for setting in settings:
            value = json.loads(setting.value) if setting.serialized else setting.value
            alias = setting.alias

            if alias == 'models':
                if isinstance(value, str):
                    # Tolerate "a, b" and trailing commas in the stored list.
                    value = [m.strip() for m in value.split(',') if m.strip()]
                config['models'] = value
            elif alias == 'threads':
                # A serialized setting can decode to an int (or None); calling
                # .isdigit() on it directly raised AttributeError.
                text = str(value).strip()
                config['max_threads'] = int(text) if text.isdigit() else 10
            elif alias in cls._PASSTHROUGH_KEYS:
                config[cls._PASSTHROUGH_KEYS[alias]] = value

        config.setdefault('models', ['gpt-3.5-turbo', 'gpt-4'])
        config.setdefault('default_model', 'gpt-3.5-turbo')
        config.setdefault('default_backup', 'gpt-3.5-turbo')
        config.setdefault('api_url', '')
        config.setdefault('api_key', '')
        config.setdefault('prompt_template', '请将以下内容翻译为{target_lang}')
        config.setdefault('max_threads', 10)

        return config
|
|
|
|
|
class TranslateProcessResource(Resource):
    """Report the progress of one of the caller's translation tasks."""

    @jwt_required()
    def post(self):
        """Return status, progress and (once done) the result path for form ``uuid``.

        Returns 400 if ``uuid`` is missing or empty, 404 if no record with that
        uuid belongs to the caller.
        """
        uuid = request.form.get('uuid')
        if not uuid:
            # Without this guard filter_by(uuid=None) would query for NULL uuids.
            return APIResponse.error("缺少必要参数", 400)

        translate = Translate.query.filter_by(
            uuid=uuid,
            customer_id=get_jwt_identity()
        ).first_or_404()

        return APIResponse.success({
            'status': translate.status,
            # process may be NULL before the task starts; report 0 instead of crashing.
            'progress': float(translate.process) if translate.process is not None else 0.0,
            'download_url': translate.target_filepath if translate.status == 'done' else None
        })
|
|
|
|
|
class TranslateDeleteResource(Resource):
    """Soft-delete a single translation record owned by the caller."""

    @jwt_required()
    def delete(self, id):
        """Mark record ``id`` as deleted (deleted_flag='Y'); 404 if the caller does not own it."""
        owned_record = (
            Translate.query
            .filter_by(id=id, customer_id=get_jwt_identity())
            .first_or_404()
        )

        # Soft delete: the row is kept, only flagged.
        owned_record.deleted_flag = 'Y'
        db.session.commit()

        return APIResponse.success(message='记录已标记为删除')
|
|
|
|
|
|
|
class TranslateDownloadResource(Resource):
    """Download a single translated file by record id."""

    # NOTE(security/review): this endpoint has no @jwt_required and does not
    # scope the lookup to a customer, so anyone who knows or guesses an id can
    # download that result. Left unchanged because direct browser links may
    # depend on it — confirm and add an ownership or signed-token check.

    def get(self, id):
        """Send the result file of translation ``id`` as an attachment.

        Returns 404 when the record does not exist or is soft-deleted, and a
        '文件不存在' 404 error when its result file is absent on disk. The
        response is marked non-cacheable.
        """
        # Soft-deleted records are hidden everywhere else; don't serve them here.
        translate = Translate.query.filter_by(
            id=id,
            deleted_flag='N',
        ).first_or_404()

        file_path = translate.target_filepath
        # isfile (not exists) so a directory path is never handed to send_file.
        if not file_path or not os.path.isfile(file_path):
            return APIResponse.error('文件不存在', 404)

        response = make_response(send_file(
            file_path,
            as_attachment=True,
            download_name=os.path.basename(file_path)
        ))

        response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0'
        response.headers['Pragma'] = 'no-cache'
        response.headers['Expires'] = '0'

        return response
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranslateDownloadAllResource(Resource):
    """Bundle all of the caller's non-deleted result files into one ZIP download."""

    @staticmethod
    def _unique_arcname(name, used):
        """Return ``name``, or ``"stem (n)ext"`` if already in ``used``; records the result in ``used``."""
        stem, ext = os.path.splitext(name)
        candidate = name
        counter = 1
        while candidate in used:
            candidate = f"{stem} ({counter}){ext}"
            counter += 1
        used.add(candidate)
        return candidate

    @jwt_required()
    def get(self):
        """Return a ZIP (built in memory) of every existing result file of the caller.

        Records whose result file is missing on disk are skipped.
        """
        records = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            deleted_flag='N'
        ).all()

        zip_buffer = BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            for record in records:
                path = record.target_filepath
                if not path or not os.path.isfile(path):
                    continue
                # Different records can share a basename; duplicate entry names
                # would overwrite each other when the archive is extracted.
                arcname = self._unique_arcname(os.path.basename(path), used_names)
                zip_file.write(path, arcname)

        zip_buffer.seek(0)

        return send_file(
            zip_buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"translations_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
        )
|
|
|
|
|
class OpenAICheckResource(Resource):
    """Probe an OpenAI-compatible endpoint with caller-supplied credentials."""

    @jwt_required()
    def post(self):
        """Check connectivity using form fields ``api_url``, ``api_key`` and ``model``.

        Returns a 400 error if any field is absent; otherwise
        ``{'valid': ..., 'message': ...}`` from AIChecker.check_openai_connection.
        """
        form = request.form
        missing = [name for name in ('api_url', 'api_key', 'model') if name not in form]
        if missing:
            return APIResponse.error('缺少必要参数', 400)

        ok, detail = AIChecker.check_openai_connection(
            form['api_url'],
            form['api_key'],
            form['model'],
        )
        return APIResponse.success({'valid': ok, 'message': detail})
|
|
|
|
|
class PDFCheckResource(Resource):
    """Check an uploaded PDF with AIChecker.check_pdf_scanned."""

    @jwt_required()
    def post(self):
        """Return ``{'scanned': ...}`` for the multipart upload field ``file``.

        400 if no file is sent or its name does not end in ".pdf"; 500 with the
        exception text if the check itself raises.
        """
        uploads = request.files
        if 'file' not in uploads:
            return APIResponse.error('请选择PDF文件', 400)

        pdf_upload = uploads['file']
        is_pdf = pdf_upload.filename.lower().endswith('.pdf')
        if not is_pdf:
            return APIResponse.error('仅支持PDF文件', 400)

        try:
            scanned = AIChecker.check_pdf_scanned(pdf_upload.stream)
            return APIResponse.success({'scanned': scanned})
        except Exception as e:
            return APIResponse.error(f'检测失败: {str(e)}', 500)
|
|
|
|
|
|
|
class TranslateTestResource(Resource):
    """Unauthenticated liveness check for the translation endpoints."""

    def get(self):
        """Always respond with a fixed success message."""
        health_message = "测试服务正常"
        return APIResponse.success(message=health_message)
|
|
|
|
|
class TranslateDeleteAllResource(Resource):
    """Remove all of the caller's translation records."""

    @jwt_required()
    def delete(self):
        """Bulk-DELETE every caller-owned row whose deleted_flag is still 'N'.

        NOTE(review): TranslateDeleteResource soft-deletes (deleted_flag='Y'),
        while this issues a hard bulk DELETE and leaves already soft-deleted
        rows in place — confirm which behaviour is intended.
        """
        live_records = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            deleted_flag='N'
        )
        live_records.delete()
        db.session.commit()

        return APIResponse.success(message="删除成功")
|
|
|
|
|
class TranslateFinishCountResource(Resource):
    """Count the caller's finished, non-deleted translations."""

    @jwt_required()
    def get(self):
        """Return ``{'total': n}`` where n counts rows with status 'done' and deleted_flag 'N'."""
        finished = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            status='done',
            deleted_flag='N'
        )
        finished_total = finished.count()
        return APIResponse.success({'total': finished_total})
|
|
|
|
|
class TranslateRandDeleteAllResource(Resource):
    """Delete all records of a temporary (anonymous) user."""

    @staticmethod
    def _json_field(name):
        """Read ``name`` from the JSON body; ``None`` if the body is absent, invalid or not an object."""
        # request.json raised/returned None for non-JSON bodies, and the
        # subsequent .get() then crashed with a 500.
        payload = request.get_json(silent=True)
        return payload.get(name) if isinstance(payload, dict) else None

    def delete(self):
        """Hard-delete every non-deleted record owned by JSON field ``rand_user_id``.

        No JWT is required: the rand_user_id alone selects the records.
        Returns 400 when rand_user_id is missing or empty.
        """
        rand_user_id = self._json_field('rand_user_id')
        if not rand_user_id:
            return APIResponse.error('需要临时用户ID', 400)

        Translate.query.filter_by(
            rand_user_id=rand_user_id,
            deleted_flag='N'
        ).delete()
        db.session.commit()
        return APIResponse.success(message="删除成功")
|
|
|
|
|
class TranslateRandDeleteResource(Resource):
    """Delete a single record of a temporary (anonymous) user."""

    @staticmethod
    def _json_field(name):
        """Read ``name`` from the JSON body; ``None`` if the body is absent, invalid or not an object."""
        payload = request.get_json(silent=True)
        return payload.get(name) if isinstance(payload, dict) else None

    def delete(self, id):
        """Hard-delete record ``id`` if it belongs to JSON field ``rand_user_id``.

        Returns 400 when rand_user_id is missing or empty, 404 when no record
        with that id belongs to it.
        """
        rand_user_id = self._json_field('rand_user_id')
        if not rand_user_id:
            # filter_by(rand_user_id=None) renders "rand_user_id IS NULL", which
            # would let an unauthenticated caller delete any record that has no
            # temporary-user id (presumably every registered user's record).
            return APIResponse.error('需要临时用户ID', 400)

        translate = Translate.query.filter_by(
            id=id,
            rand_user_id=rand_user_id
        ).first_or_404()

        db.session.delete(translate)
        db.session.commit()
        return APIResponse.success(message="删除成功")
|
|
|
|
|
class TranslateRandDownloadResource(Resource):
    """ZIP download of a temporary (anonymous) user's finished translations."""

    @staticmethod
    def _unique_arcname(name, used):
        """Return ``name``, or ``"stem (n)ext"`` if already in ``used``; records the result in ``used``."""
        stem, ext = os.path.splitext(name)
        candidate = name
        counter = 1
        while candidate in used:
            candidate = f"{stem} ({counter}){ext}"
            counter += 1
        used.add(candidate)
        return candidate

    def get(self):
        """Return a ZIP of the 'done' result files owned by query arg ``rand_user_id``.

        Returns 400 when rand_user_id is missing or empty. Records whose result
        file is absent on disk are skipped.
        """
        rand_user_id = request.args.get('rand_user_id')
        if not rand_user_id:
            # filter_by(rand_user_id=None) renders "rand_user_id IS NULL" and
            # would bundle every finished record without a temporary-user id
            # (presumably all registered users' files) for an anonymous caller.
            return APIResponse.error('需要临时用户ID', 400)

        records = Translate.query.filter_by(
            rand_user_id=rand_user_id,
            status='done'
        ).all()

        zip_buffer = BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            for record in records:
                path = record.target_filepath
                # os.path.exists(None) raises TypeError; also skip directories.
                if not path or not os.path.isfile(path):
                    continue
                # Avoid duplicate entry names when several records share a basename.
                arcname = self._unique_arcname(os.path.basename(path), used_names)
                zip_file.write(path, arcname)

        zip_buffer.seek(0)
        return send_file(
            zip_buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"temp_translations_{datetime.now().strftime('%Y%m%d')}.zip"
        )
|
|
|
|
|
class Doc2xCheckResource(Resource):
    """Validate a Doc2x secret key."""

    def post(self):
        """Check JSON field ``doc2x_secret_key``; 400 '无效密钥' unless it is accepted.

        NOTE(review): this is a stub — it accepts only the hard-coded key
        "valid_key_123" and never contacts Doc2x. Replace with a real check.
        """
        # request.json could be None (or raise) for non-JSON bodies, which made
        # the .get() call crash with a 500 instead of reporting an invalid key.
        payload = request.get_json(silent=True)
        secret_key = payload.get('doc2x_secret_key') if isinstance(payload, dict) else None

        if secret_key == "valid_key_123":
            return APIResponse.success(message="接口正常")
        return APIResponse.error("无效密钥", 400)
|
|
|