| import torch
|
| import torch.nn.functional as F
|
| from model import MiniText
|
| import random
|
|
|
|
|
|
|
|
|
# Path to the trained checkpoint produced by the training script.
MODEL_PATH = "minitext.pt"

# Inference runs on CPU only in this script.
DEVICE = "cpu"


# Load the model once at import time and switch to inference mode.
# NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source (consider weights_only=True).
model = MiniText().to(DEVICE)

model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

model.eval()
|
|
|
|
|
|
|
|
|
def sample_logits(logits, temperature=1.0, top_k=0):
    """Sample one token id from a ``(batch, vocab)`` logits tensor.

    Args:
        logits: float tensor of shape ``(batch, vocab)``; only batch
            size 1 is meaningful since a single Python int is returned.
        temperature: softmax temperature. Values <= 0 fall back to
            greedy argmax instead of dividing by zero.
        top_k: if > 0, restrict sampling to the ``k`` most likely
            tokens (clamped to the vocabulary size).

    Returns:
        The sampled token id as a Python ``int``.
    """
    # Greedy guard: temperature 0 would otherwise produce inf/NaN logits.
    if temperature <= 0:
        return torch.argmax(logits, dim=-1).item()

    logits = logits / temperature

    if top_k > 0:
        # Clamp so torch.topk never receives k larger than the vocab,
        # which would raise a RuntimeError.
        k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        min_val = values[:, -1].unsqueeze(-1)
        # Push everything below the k-th best value to a huge negative
        # number so softmax assigns it (effectively) zero probability.
        logits = torch.where(logits < min_val, torch.full_like(logits, -1e9), logits)

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).item()
|
|
|
|
|
|
|
|
|
def generate(
    prompt="o",
    max_new_tokens=300,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    seed=None,
    h=None
):
    """Autoregressively generate bytes from the module-level ``model``.

    The prompt is encoded to UTF-8 bytes and fed through the model in one
    pass to warm up the recurrent hidden state, then ``max_new_tokens``
    bytes are sampled one at a time.

    Args:
        prompt: seed text; must encode to at least one UTF-8 byte.
        max_new_tokens: number of bytes to sample after the prompt.
        temperature: softmax temperature forwarded to ``sample_logits``.
        top_k: top-k cutoff forwarded to ``sample_logits``.
        top_p: NOTE(review) — accepted for API compatibility but NOT
            applied; nucleus sampling is not implemented here.
        repetition_penalty: NOTE(review) — accepted for API
            compatibility but NOT applied.
        seed: optional RNG seed for reproducible sampling.
        h: optional hidden state carried over from a previous call
            (``None`` starts fresh).

    Returns:
        ``(text, h)``: prompt plus generated bytes decoded as UTF-8
        (undecodable bytes dropped), and the final hidden state.

    Raises:
        ValueError: if ``prompt`` encodes to zero bytes.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)

    bytes_in = list(prompt.encode("utf-8", errors="ignore"))
    if not bytes_in:
        # An empty prompt would build a zero-length tensor and fail with
        # an opaque error at x[:, -1:]; fail early with a clear message.
        raise ValueError("prompt must encode to at least one UTF-8 byte")

    output = bytes_in.copy()

    # Warm up the hidden state on the whole prompt in a single pass.
    x = torch.tensor([bytes_in], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        _, h = model(x, h)

    last = x[:, -1:]

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits, h = model(last, h)

        # Sample from the distribution at the final time step.
        next_byte = sample_logits(
            logits[:, -1],
            temperature=temperature,
            top_k=top_k
        )

        output.append(next_byte)
        last = torch.tensor([[next_byte]], device=DEVICE)

    return bytes(output).decode(errors="ignore"), h
|
|
|
# Recurrent hidden state carried across turns so the model keeps
# conversational context between replies.
h = None

print("MiniText-v1.5 Chat | digite 'exit' para sair")

while True:
    user = input("usuario: ")
    # BUG FIX: the banner tells the user to type 'exit' but the loop
    # previously only broke on 'quit'. Accept both (backward compatible)
    # and tolerate surrounding whitespace.
    if user.strip().lower() in ("exit", "quit"):
        break

    # Frame the turn the same way the training data (presumably) was
    # framed, so the model continues after the "ia: " marker.
    prompt = f"usuario: {user}\nia: "
    text, h = generate(
        prompt=prompt,
        max_new_tokens=120,
        temperature=0.5,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        h=h
    )

    # Keep only the model's reply after the last "ia:" marker; the
    # returned text also contains the echoed prompt.
    reply = text.split("ia:")[-1].strip()
    print("ia:", reply)
|
|
|
|
|