AlsuGibadullina's picture
Update app.py
9c49ab5 verified
import os
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
HF_MODELS = {
"Qwen2.5-32B-Instruct": "Qwen/Qwen2.5-32B-Instruct",
"DeepSeek-R1-Distill-Qwen-32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
"DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Mistral-Small-24B-Instruct": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
}
@dataclass
class RequirementResult:
original_requirement: str
model_name: str
original_score: Optional[float]
refactored_requirement: str
found_issues: List[str]
refactoring_time_sec: float
status: str
raw_response: str
SYSTEM_PROMPT = """
Ты — эксперт по инженерии требований и системному анализу.
Твоя задача:
1. Проанализировать исходное требование.
2. Найти ошибки и проблемы качества в исходном требовании.
3. Оценить качество исходного требования числом от 0 до 1:
- 0 = очень плохое требование
- 1 = качественное требование
4. Выполнить рефакторинг требования, улучшив его.
ВАЖНО:
- Отвечай только на русском языке.
- Верни только корректный JSON.
- Не добавляй markdown, пояснений или текста вне JSON.
Формат ответа:
{
"original_score": 0.0,
"refactored_requirement": "строка",
"found_issues": ["ошибка 1", "ошибка 2"],
"comment": "краткое пояснение"
}
Правила:
- original_score должен быть числом от 0 до 1.
- refactored_requirement должен быть одной улучшенной формулировкой.
- found_issues должен содержать список найденных проблем в исходном требовании.
- Все значения должны быть на русском языке.
""".strip()
def build_user_prompt(requirement: str) -> str:
return f"""
Проанализируй и отрефактори следующее требование:
{requirement}
""".strip()
def make_hf_client() -> InferenceClient:
if HF_TOKEN:
return InferenceClient(token=HF_TOKEN)
return InferenceClient()
def safe_json_extract(text: str) -> Dict[str, Any]:
text = (text or "").strip()
try:
return json.loads(text)
except Exception:
pass
text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE).strip()
text = re.sub(r"```$", "", text).strip()
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if match:
return json.loads(match.group(0))
raise ValueError("Модель не вернула корректный JSON")
def normalize_score(value: Any) -> Optional[float]:
try:
score = float(value)
score = max(0.0, min(1.0, score))
return round(score, 3)
except Exception:
return None
def normalize_result(data: Dict[str, Any]) -> Dict[str, Any]:
issues = data.get("found_issues", [])
if not isinstance(issues, list):
issues = [str(issues)]
issues = [str(x).strip() for x in issues if str(x).strip()]
return {
"original_score": normalize_score(data.get("original_score")),
"refactored_requirement": str(data.get("refactored_requirement", "")).strip(),
"found_issues": issues,
"comment": str(data.get("comment", "")).strip(),
}
def parse_requirements_text(raw_text: str) -> List[str]:
raw_text = (raw_text or "").strip()
if not raw_text:
return []
lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
cleaned = []
for line in lines:
line = re.sub(r"^\d+[\).\s-]+", "", line).strip()
line = re.sub(r"^[-•*]\s*", "", line).strip()
if line:
cleaned.append(line)
return cleaned
def load_requirements_from_file(file_obj) -> List[str]:
if file_obj is None:
return []
path = file_obj.name
ext = os.path.splitext(path)[1].lower()
if ext == ".csv":
df = pd.read_csv(path)
possible_columns = [
"requirement", "requirements", "text", "description",
"требование", "требования", "текст", "описание"
]
for col in possible_columns:
if col in df.columns:
return [str(x).strip() for x in df[col].dropna().tolist() if str(x).strip()]
first_col = df.columns[0]
return [str(x).strip() for x in df[first_col].dropna().tolist() if str(x).strip()]
if ext == ".txt":
with open(path, "r", encoding="utf-8") as f:
return parse_requirements_text(f.read())
raise ValueError("Поддерживаются только .txt и .csv")
def call_model(
model_id: str,
requirement: str,
temperature: float,
max_tokens: int,
) -> Tuple[str, Dict[str, Any]]:
client = make_hf_client()
response = client.chat.completions.create(
model=model_id,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": build_user_prompt(requirement)},
],
temperature=temperature,
max_tokens=max_tokens,
)
text = response.choices[0].message.content
parsed = safe_json_extract(text)
return text, normalize_result(parsed)
def process_single_requirement(
requirement: str,
model_label: str,
temperature: float,
max_tokens: int,
) -> RequirementResult:
start = time.perf_counter()
try:
if model_label not in HF_MODELS:
raise RuntimeError(f"Неизвестная модель: {model_label}")
raw_response, parsed = call_model(
model_id=HF_MODELS[model_label],
requirement=requirement,
temperature=temperature,
max_tokens=max_tokens,
)
elapsed = round(time.perf_counter() - start, 3)
return RequirementResult(
original_requirement=requirement,
model_name=model_label,
original_score=parsed["original_score"],
refactored_requirement=parsed["refactored_requirement"],
found_issues=parsed["found_issues"],
refactoring_time_sec=elapsed,
status="ok",
raw_response=raw_response,
)
except Exception as e:
elapsed = round(time.perf_counter() - start, 3)
return RequirementResult(
original_requirement=requirement,
model_name=model_label,
original_score=None,
refactored_requirement="",
found_issues=[],
refactoring_time_sec=elapsed,
status=f"error: {str(e)}",
raw_response="",
)
def build_results_dataframe(results: List[RequirementResult]) -> pd.DataFrame:
rows = []
for r in results:
rows.append({
"Изначальное требование": r.original_requirement,
"Модель": r.model_name,
"Оценка изначального требования": r.original_score,
"Версия требования после рефакторинга": r.refactored_requirement,
"Найденные ошибки в изначальном требовании": "; ".join(r.found_issues),
"Время рефакторинга": r.refactoring_time_sec,
})
return pd.DataFrame(rows)
def save_results(df: pd.DataFrame, raw_results: List[RequirementResult]) -> Tuple[str, str]:
csv_path = "results_table.csv"
json_path = "results_raw.json"
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
with open(json_path, "w", encoding="utf-8") as f:
json.dump([asdict(x) for x in raw_results], f, ensure_ascii=False, indent=2)
return csv_path, json_path
def run_benchmark(
raw_requirements: str,
uploaded_file,
selected_models: List[str],
temperature: float,
max_tokens: int,
max_parallel_calls: int,
):
requirements = []
if raw_requirements.strip():
requirements.extend(parse_requirements_text(raw_requirements))
if uploaded_file is not None:
requirements.extend(load_requirements_from_file(uploaded_file))
unique_requirements = []
seen = set()
for req in requirements:
if req not in seen:
unique_requirements.append(req)
seen.add(req)
if not unique_requirements:
raise gr.Error("Добавь требования текстом или загрузи CSV/TXT файл.")
if not selected_models:
raise gr.Error("Выбери хотя бы одну модель.")
results: List[RequirementResult] = []
futures = []
with ThreadPoolExecutor(max_workers=max_parallel_calls) as executor:
for req in unique_requirements:
for model_name in selected_models:
futures.append(
executor.submit(
process_single_requirement,
requirement=req,
model_label=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
)
for future in as_completed(futures):
results.append(future.result())
results.sort(key=lambda x: (x.original_requirement, x.model_name))
df = build_results_dataframe(results)
csv_path, json_path = save_results(df, results)
stats = (
f"Обработано требований: {len(unique_requirements)}\n"
f"Выбрано моделей: {len(selected_models)}\n"
f"Всего запусков: {len(results)}"
)
return stats, df, csv_path, json_path
def preview_requirements(raw_requirements: str, uploaded_file):
requirements = []
if raw_requirements.strip():
requirements.extend(parse_requirements_text(raw_requirements))
if uploaded_file is not None:
try:
requirements.extend(load_requirements_from_file(uploaded_file))
except Exception as e:
return f"Ошибка чтения файла: {e}"
unique_requirements = []
seen = set()
for req in requirements:
if req not in seen:
unique_requirements.append(req)
seen.add(req)
if not unique_requirements:
return "Требования не найдены."
preview = "\n".join(f"{i+1}. {req}" for i, req in enumerate(unique_requirements[:20]))
if len(unique_requirements) > 20:
preview += f"\n... ещё {len(unique_requirements) - 20}"
return f"Найдено требований: {len(unique_requirements)}\n\n{preview}"
with gr.Blocks(title="Рефакторинг требований с помощью LLM") as demo:
gr.Markdown(
"""
# Сравнение моделей для анализа и рефакторинга требований
Приложение принимает:
- список требований, где каждое требование идет с новой строки;
- или CSV-файл с требованиями.
Для каждого требования каждая выбранная модель:
- оценивает исходное требование;
- делает рефакторинг;
- находит ошибки;
- измеряется время выполнения.
"""
)
with gr.Row():
with gr.Column(scale=2):
raw_requirements = gr.Textbox(
label="Список требований",
lines=14,
placeholder="Каждое новое требование — с новой строки"
)
uploaded_file = gr.File(
label="Или загрузи TXT / CSV файл",
file_types=[".txt", ".csv"]
)
with gr.Column(scale=1):
model_selector = gr.CheckboxGroup(
choices=list(HF_MODELS.keys()),
value=[
"Qwen2.5-32B-Instruct",
"DeepSeek-R1-Distill-Qwen-32B",
"Llama-3.1-8B-Instruct",
],
label="Модели"
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.1,
label="Температура"
)
max_tokens = gr.Slider(
minimum=256,
maximum=2048,
value=1024,
step=128,
label="Максимум токенов"
)
max_parallel_calls = gr.Slider(
minimum=1,
maximum=8,
value=3,
step=1,
label="Параллельных запросов"
)
preview_btn = gr.Button("Проверить входные данные")
run_btn = gr.Button("Запустить", variant="primary")
preview_box = gr.Textbox(
label="Предпросмотр",
lines=10
)
preview_btn.click(
fn=preview_requirements,
inputs=[raw_requirements, uploaded_file],
outputs=preview_box,
)
stats_box = gr.Textbox(label="Статистика")
results_table = gr.Dataframe(label="Результаты", wrap=True, interactive=False)
with gr.Row():
results_csv = gr.File(label="Скачать CSV")
results_json = gr.File(label="Скачать JSON")
run_btn.click(
fn=run_benchmark,
inputs=[
raw_requirements,
uploaded_file,
model_selector,
temperature,
max_tokens,
max_parallel_calls,
],
outputs=[
stats_box,
results_table,
results_csv,
results_json,
],
)
if __name__ == "__main__":
demo.launch()