vectrayx-paper-code / eval /run_inference_lora.py
jsantillana's picture
Upload folder using huggingface_hub
4da4469 verified
#!/usr/bin/env python3
"""
Inferencia con LoRA aplicado correctamente.
Estrategia:
1. Carga el modelo base (nano_sft_v5.pt) con pesos completos
2. Inyecta LoRA en wq/wk/wv/wo
3. Carga los pesos LoRA desde final_lora_only.pt
4. Corre inferencia
Esto evita el problema del merge con tie_embeddings=True.
"""
import argparse
import json
import math
import sys
from pathlib import Path
import torch
import torch.nn as nn
import sentencepiece as spm
# Make the `training_v2` package importable for the imports below.
# parents[2] is two directory levels above the directory containing this file
# (presumably the directory holding `training_v2/` — confirm against the repo layout).
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))
from training_v2.model.transformer import VectraYXNano, ModelConfig
from training_v2.train.utils import load_checkpoint
# ─── LoRA (same implementation as finetune_lora_tools.py) ────────────────────
class LoRALinear(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable low-rank (LoRA) update.

    Output is ``linear(x) + scale * (x @ A.T) @ B.T`` with
    ``A`` of shape (rank, in_features), ``B`` of shape (out_features, rank)
    and ``scale = alpha / rank``. ``B`` is zero-initialised, so a freshly
    wrapped layer returns exactly what the wrapped linear returns until LoRA
    weights are loaded.

    The attribute names ``linear``, ``lora_A`` and ``lora_B`` determine the
    state_dict keys and therefore must match the checkpoint being loaded.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.linear = linear
        self.rank = rank
        self.scale = alpha / rank
        # Down-projection A (Kaiming-uniform) and up-projection B (zeros).
        self.lora_A = nn.Parameter(torch.empty(rank, linear.in_features))
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # Only the low-rank factors remain trainable.
        for frozen in self.linear.parameters():
            frozen.requires_grad_(False)

    def forward(self, x):
        dev = x.device
        down = self.lora_A.to(dev)
        up = self.lora_B.to(dev)
        delta = x @ down.T @ up.T
        return self.linear(x) + delta * self.scale
def inject_lora(model, rank, alpha, targets=("wq", "wk", "wv", "wo")):
    """Wrap every ``nn.Linear`` attribute named in *targets* with ``LoRALinear``.

    Every submodule of *model* is inspected; an attribute whose name is in
    *targets* and whose value is an ``nn.Linear`` is replaced in place by
    ``LoRALinear(orig, rank, alpha)``. Afterwards every parameter whose name
    does not contain ``lora_A``/``lora_B`` is frozen, and a trainable/total
    parameter summary is printed.

    Args:
        model: module modified in place.
        rank: LoRA rank forwarded to LoRALinear.
        alpha: LoRA alpha forwarded to LoRALinear (scale = alpha / rank).
        targets: attribute names to wrap.

    Returns:
        Number of linear layers that were wrapped (previously None; callers
        ignoring the result are unaffected).
    """
    # Snapshot the module list before mutating: replacing children while the
    # lazy named_modules() generator is still walking the tree makes the
    # traversal depend on our own edits.
    modules = [module for _, module in model.named_modules()]
    wrapped = 0
    for module in modules:
        for attr in targets:
            orig = getattr(module, attr, None)
            if isinstance(orig, nn.Linear):
                setattr(module, attr, LoRALinear(orig, rank, alpha))
                wrapped += 1
    if wrapped == 0:
        # Otherwise a naming mismatch would go unnoticed until LoRA weights
        # silently fail to load.
        print(f"[warn] inject_lora: ningún atributo {tuple(targets)} de tipo nn.Linear encontrado")
    # Freeze everything except the LoRA factors (embeddings, norms, etc.).
    for name, p in model.named_parameters():
        if "lora_A" not in name and "lora_B" not in name:
            p.requires_grad_(False)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    pct = trainable / total * 100 if total else 0.0  # guard parameterless models
    print(f"[lora] {trainable/1e3:.1f}K entrenables / {total/1e6:.2f}M total ({pct:.2f}%)")
    return wrapped
def load_lora_weights(path, model, device):
    """Load only the LoRA tensors (``lora_A`` / ``lora_B``) from *path* into *model*.

    The checkpoint's ``"lora_state_dict"`` entry is used when present and
    non-empty; otherwise any top-level keys containing ``lora_A``/``lora_B``
    are used. Loading is non-strict so already-loaded base weights are kept.

    Args:
        path: checkpoint file passed to ``torch.load``.
        model: module that already has LoRALinear layers injected.
        device: ``map_location`` for ``torch.load``.

    Returns:
        Number of LoRA tensors that matched a parameter of *model* and were
        loaded (0 if none were found or none matched). Previously None;
        callers ignoring the result are unaffected.

    Raises:
        RuntimeError: from ``load_state_dict`` on a tensor shape mismatch.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        print(f"[warn] No se encontraron pesos LoRA en {path}")
        return 0
    lora_state = ckpt.get("lora_state_dict", {})
    if not lora_state:
        # Fallback for checkpoints that store LoRA tensors at the top level.
        lora_state = {k: v for k, v in ckpt.items()
                      if "lora_A" in k or "lora_B" in k}
    if not lora_state:
        print(f"[warn] No se encontraron pesos LoRA en {path}")
        return 0
    missing, unexpected = model.load_state_dict(lora_state, strict=False)
    # With strict=False every base weight appears in `missing`; only the LoRA
    # entries are meaningful for this load.
    missing_lora = [k for k in missing if "lora_A" in k or "lora_B" in k]
    loaded = len(lora_state) - len(unexpected)
    print(f"[lora] Cargados {loaded}/{len(lora_state)} pesos LoRA | "
          f"missing_lora={len(missing_lora)} unexpected={len(unexpected)}")
    if unexpected:
        # e.g. a "module." / "_orig_mod." prefix from DDP or torch.compile.
        print(f"[warn] {len(unexpected)} claves LoRA no coinciden con el modelo (ej. {unexpected[0]})")
    if missing_lora:
        print(f"[warn] {len(missing_lora)} parámetros LoRA del modelo no se cargaron (ej. {missing_lora[0]})")
    return loaded
# ─── Prompts ──────────────────────────────────────────────────────────────────
# System message placed before every benchmark question by build_prompt. It
# lists the five MCP tools and the exact <|tool_call|>{json}<|/tool_call|>
# format that run_benchmark later searches for in the model's response.
SYSTEM_PROMPT = (
    "Eres VectraYX, un asistente de ciberseguridad en español. "
    "Tienes 5 herramientas MCP disponibles:\n"
    "- nvd_get_cve(cve_id): obtener detalle de un CVE\n"
    "- nvd_search(query, limit): buscar CVEs por palabra clave\n"
    "- cisa_kev_check(cve_id): comprobar si un CVE está en el catálogo KEV\n"
    "- otx_check_ioc(ioc_type, value): reputación de IOC (ip, domain, hash)\n"
    "- bash_exec(cmd): ejecutar comando shell local\n"
    "Cuando la pregunta requiera datos externos o ejecutar algo, emite EXACTAMENTE:\n"
    '<|tool_call|>{"name":"<tool>","args":{...}}<|/tool_call|>\n'
    "Si la pregunta es conversacional o conceptual, responde en prosa SIN llamar herramientas."
)
# Benchmark cases as (question, expected tool name) pairs. None marks a
# conversational/conceptual question that should be answered without a tool
# call; run_benchmark counts only the non-None entries toward the B4 score.
TEST_PROMPTS = [
    ("qué hora es", "bash_exec"),
    ("dame la fecha de hoy", "bash_exec"),
    ("quién soy", "bash_exec"),
    ("en qué directorio estoy", "bash_exec"),
    ("cuánta memoria libre hay", "bash_exec"),
    ("uso de disco", "bash_exec"),
    ("qué sistema operativo es", "bash_exec"),
    ("cuál es mi IP", "bash_exec"),
    ("lista los archivos aquí", "bash_exec"),
    ("qué puertos están escuchando", "bash_exec"),
    ("dame detalles de CVE-2021-44228", "nvd_get_cve"),
    ("busca CVEs de log4j", "nvd_search"),
    ("está CVE-2021-44228 en KEV", "cisa_kev_check"),
    ("es maliciosa la IP 45.155.205.12", "otx_check_ioc"),
    # Expected: plain-prose answer, no tool call.
    ("qué es un zero-day", None),
    ("hola, cómo estás", None),
]
def build_prompt(sp, question, system_prompt=None):
    """Encode a single-turn chat prompt that ends with the assistant tag.

    Layout: ``<|system|>{system}<|end|><|user|>{question}<|end|><|assistant|>``.

    Args:
        sp: tokenizer exposing ``encode(text, out_type=int)`` (SentencePiece).
        question: user message.
        system_prompt: system message; when None, the module-level
            SYSTEM_PROMPT is used (looked up at call time).

    Returns:
        ``torch.long`` tensor of shape (1, num_tokens).
    """
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT
    text = (f"<|system|>{system_prompt}<|end|>"
            f"<|user|>{question}<|end|>"
            f"<|assistant|>")
    return torch.tensor([sp.encode(text, out_type=int)], dtype=torch.long)
@torch.no_grad()
def generate(model, input_ids, sp, max_new=100, temperature=0.7,
             top_k=40, top_p=0.9, repeat_penalty=1.3, repeat_window=50):
    """Sample up to *max_new* tokens after *input_ids* and return the decoded text.

    Each step re-runs *model* on the whole sequence (no KV cache) and uses the
    logits of the last position. Per step:
      1. repetition penalty on tokens among the last *repeat_window* generated
         ones (CTRL rule: positive logits are divided, negative ones multiplied,
         so the penalty always lowers a repeated token's score);
      2. temperature scaling — ``temperature <= 0`` selects greedy argmax;
      3. top-k filtering, then nucleus (top-p) filtering, then sampling.
    Generation stops after the ``<|end|>`` piece is emitted; that token is
    included in the decoded output.

    Args:
        model: module returning ``(logits, _)``; logits assumed to be
            (batch, seq, vocab) — only batch index 0 is used.
        input_ids: prompt token ids, assumed shape (1, seq).
        sp: tokenizer with ``piece_to_id`` and ``decode``.
        max_new: maximum number of tokens to generate.
        temperature: softmax temperature; <= 0 means greedy decoding.
        top_k: keep only the k highest logits (0 disables).
        top_p: nucleus mass; the most likely token is always kept.
        repeat_penalty: penalty factor (1.0 disables).
        repeat_window: how many recent generated tokens are penalised
            (<= 0 disables).

    Returns:
        Decoded string of the generated tokens only (prompt excluded).
    """
    device = next(model.parameters()).device
    ids = input_ids.to(device)
    # NOTE(review): if "<|end|>" is not in the vocab this is the unk id, and
    # generation would stop on any unk token.
    end_id = sp.piece_to_id("<|end|>")
    generated = []
    for _ in range(max_new):
        logits, _ = model(ids)
        # float32 for numerically stable softmax/penalty math.
        logits = logits[0, -1, :].float()

        if repeat_penalty != 1.0 and generated and repeat_window > 0:
            recent = torch.tensor(sorted(set(generated[-repeat_window:])),
                                  device=logits.device)
            scores = logits[recent]
            # Dividing a negative logit would *raise* it; multiply instead.
            logits[recent] = torch.where(scores > 0,
                                         scores / repeat_penalty,
                                         scores * repeat_penalty)

        if temperature <= 0:
            # Greedy: top-k/top-p filtering can never change the argmax.
            next_tok = int(torch.argmax(logits).item())
        else:
            logits = logits / temperature
            if top_k > 0:
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[-1]
                logits[logits < kth] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            # Drop tokens whose preceding cumulative mass already exceeds top_p.
            cumsum = torch.cumsum(sorted_probs, dim=0)
            sorted_probs[cumsum - sorted_probs > top_p] = 0
            sorted_probs /= sorted_probs.sum() + 1e-8
            next_tok = sorted_idx[torch.multinomial(sorted_probs, 1)].item()

        generated.append(next_tok)
        ids = torch.cat([ids, torch.tensor([[next_tok]], device=device)], dim=1)
        if next_tok == end_id:
            break
    return sp.decode(generated)
def _parse_tool_call(response):
    """Return (name, args) from the first ``<|tool_call|>{json}<|/tool_call|>`` span.

    Returns (None, None) when there is no complete span, the payload is not
    valid JSON, or it is not a JSON object.
    """
    open_tag, close_tag = "<|tool_call|>", "<|/tool_call|>"
    start = response.find(open_tag)
    if start < 0:
        return None, None
    start += len(open_tag)
    # Look for the closing tag only *after* the opening one; a stray earlier
    # close tag must not yield an empty or garbled slice.
    end = response.find(close_tag, start)
    if end < 0:
        return None, None
    try:
        call = json.loads(response[start:end])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None, None
    if not isinstance(call, dict):
        return None, None
    return call.get("name"), call.get("args", {})


def run_benchmark(model, sp, out_path):
    """Run every TEST_PROMPTS case, print a per-case report and save it as JSON.

    Cases with an expected tool count toward the B4 score (exact tool-name
    match). Conversational cases (expected None) pass when no tool call is
    parsed, but they do not affect the score.

    Args:
        model: model forwarded to generate().
        sp: SentencePiece processor used for prompting and decoding.
        out_path: destination of the JSON report; parent dirs are created.

    Returns:
        B4 score: correct / number of tool-expecting prompts (0 if there are none).
    """
    results = []
    correct = 0
    total_tool = 0
    print("\n" + "="*70)
    print("BENCHMARK — Nano LoRA v3")
    print("="*70)
    for question, expected_tool in TEST_PROMPTS:
        input_ids = build_prompt(sp, question)
        response = generate(model, input_ids, sp)
        detected_tool, detected_args = _parse_tool_call(response)
        if expected_tool is not None:
            total_tool += 1
            ok = detected_tool == expected_tool
            if ok:
                correct += 1
            status = "✅" if ok else "❌"
        else:
            ok = detected_tool is None
            status = "✅" if ok else "⚠️"
        print(f"\n[{status}] Q: {question}")
        print(f" Expected: {expected_tool or 'ninguna'}")
        print(f" Got: {detected_tool or 'ninguna'}")
        if detected_args:
            print(f" Args: {detected_args}")
        # Console preview: first 150 chars, code points >= U+4000 (e.g. CJK
        # noise) replaced with '?'.
        clean = ''.join(c if ord(c) < 0x4000 else '?' for c in response[:150])
        print(f" Raw: {clean.strip()}")
        results.append({
            "question": question,
            "expected_tool": expected_tool,
            "detected_tool": detected_tool,
            "detected_args": detected_args,
            "response": response,
            "correct": ok,
        })
    b4 = correct / total_tool if total_tool > 0 else 0
    print("\n" + "="*70)
    print(f"B4 SCORE: {correct}/{total_tool} = {b4:.3f}")
    print("="*70)
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"b4_score": b4, "correct": correct,
                   "total_tool": total_tool, "results": results},
                  f, indent=2, ensure_ascii=False)
    print(f"\n[saved] {out_path}")
    return b4
def _build_arg_parser():
    """Return the CLI parser for this script (checkpoints, config, tokenizer, LoRA hparams)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-checkpoint", required=True,
                        help="Checkpoint base del modelo (nano_sft_v5.pt)")
    parser.add_argument("--lora-checkpoint", required=True,
                        help="Checkpoint LoRA (final_lora_only.pt)")
    parser.add_argument("--config", required=True)
    parser.add_argument("--tokenizer", required=True)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=float, default=32.0)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--out", default="/tmp/nano_lora_bench_v2.json")
    return parser


def main():
    """CLI entry point: load the base model, add and load LoRA, run the benchmark."""
    opts = _build_arg_parser().parse_args()
    device = opts.device

    # Base model: architecture from the JSON config, then the full base weights.
    model = VectraYXNano(ModelConfig.from_json(opts.config)).to(device)
    print(f"[model] {model.num_params()/1e6:.2f}M params")
    load_checkpoint(opts.base_checkpoint, model, map_location=device)
    print(f"[base] cargado: {opts.base_checkpoint}")

    # LoRALinear allocates its new parameters without a device argument,
    # so move the model again after injection.
    inject_lora(model, rank=opts.lora_rank, alpha=opts.lora_alpha)
    model = model.to(device)

    # LoRA adapter weights on top of the base model.
    load_lora_weights(opts.lora_checkpoint, model, device)
    print(f"[lora] cargado: {opts.lora_checkpoint}")
    model.eval()

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(opts.tokenizer)

    run_benchmark(model, tokenizer, opts.out)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()