# Provenance (Hugging Face Space file-viewer artifact, kept as a comment):
#   uploader: FD900 — commit "Update agent.py" (e52cce4, verified) — 1.79 kB
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# ---- Model setup: FLAN-T5 base wrapped in a text2text pipeline ----
model_name = "google/flan-t5-base"

# Fetch the seq2seq weights and the matching tokenizer from the HF hub.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Shared generation pipeline used by BasicAgent below.
generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
# GAIA system prompt (applied manually)
# GAIA answer-format instructions, prepended to every question before
# generation. BasicAgent later extracts the text after "FINAL ANSWER:".
system_prompt = (
    "You are a general AI assistant. I will ask you a question. "
    "Report your thoughts, and finish your answer with the following template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER]. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR "
    "a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a string.\n"
)
class BasicAgent:
    """Minimal GAIA-style agent backed by the module-level FLAN-T5 pipeline.

    Prepends the GAIA ``system_prompt`` to the question, runs the shared
    ``generator`` pipeline, and returns only the text that follows the
    "FINAL ANSWER:" marker (or the whole stripped output if the marker
    is absent).
    """

    def __init__(self):
        # No per-instance state: model, tokenizer and pipeline are
        # module-level singletons loaded at import time.
        print("Flan-T5 GAIA agent initialized.")

    def __call__(self, question: str) -> str:
        """Answer *question* and return the extracted final answer.

        Generation failures are reported as an ``"ERROR: ..."`` string
        rather than raised, so callers never see an exception from here.
        """
        prompt = system_prompt + "\nQuestion: " + question
        try:
            # Greedy decoding (do_sample=False) keeps answers deterministic.
            result = generator(prompt, max_length=256, do_sample=False)[0]["generated_text"]
        except Exception as e:  # boundary: surface the failure as a string
            return f"ERROR: {e}"
        # Fix: dropped the dead `final_answer = "None"` assignment — it was
        # unconditionally overwritten by both branches below.
        marker = "FINAL ANSWER:"
        if marker in result:
            # Take the text after the LAST marker occurrence, per the prompt.
            return result.split(marker)[-1].strip()
        return result.strip()