| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| from typing import List, Optional |
|
|
| import fire |
|
|
| from llama import Llama, Dialog |
|
|
| import sys |
| from ast import literal_eval |
|
|
| import json |
|
|
def tuple2dialog(x):
    """Build a single-turn chat dialog asking the model to judge a claim.

    Args:
        x: A tuple where x[1] is the claim to evaluate and x[2] is the
           string the claim is about (other positions are ignored here).

    Returns:
        A one-message dialog (list with a single user-role dict).
    """
    # NOTE(review): the original also built a system prompt ("Always answer
    # with a single word 'True' or 'False'") but never included it in the
    # returned dialog; that dead local is removed. If the system instruction
    # is actually wanted, it must be prepended to the returned list — confirm
    # against the intended evaluation protocol.
    userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
    return [userprompt]
|
|
| |
| |
| |
|
|
| |
def tuple_add_dialog(x):
    """Append a generated chat dialog to a 4-element benchmark tuple.

    Args:
        x: A 4-tuple where x[1] is the claim to evaluate and x[2] is the
           string the claim is about; x[0] and x[3] are passed through.

    Returns:
        A 5-tuple (x[0], x[1], x[2], x[3], dialog) where dialog is a
        one-message list with a single user-role dict.
    """
    # NOTE(review): the original also built a system prompt but never used it
    # in the returned dialog; the dead local is removed. If a system
    # instruction is intended, prepend it to the dialog list — confirm.
    userprompt = {"role": "user", "content": f"Consider the string '{x[2]}'. True or False: {x[1]}"}
    return (x[0], x[1], x[2], x[3], [userprompt])
|
|
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    benchmark_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 8,
    max_gen_len: Optional[int] = None,
    tuple: Optional[bool] = False,
):
    """
    Entry point of the program for generating text using a pretrained model.

    Args:
        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
        benchmark_path (str): The path to the benchmark e.g. benchmark/cl23.txt.
        temperature (float, optional): The temperature value for controlling randomness in generation.
            Defaults to 0.6.
        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
            Defaults to 0.9.
        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
        max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
            set to the model's max sequence length. Defaults to None.
        tuple (bool, optional): When True, print each result as a machine-readable
            tuple line; otherwise print a human-readable transcript. Defaults to False.
            NOTE(review): this parameter shadows the builtin ``tuple``; the name is kept
            because it is part of the fire CLI interface (--tuple flag).
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # `with` ensures the benchmark file is closed (the original leaked the
    # handle). Each line is a Python tuple literal; literal_eval parses it
    # safely, and tuple_add_dialog appends the chat dialog as element 4.
    with open(benchmark_path) as benchmark_file:
        benchmark_stream = map(
            lambda line: tuple_add_dialog(literal_eval(line)), benchmark_file
        )

        # zip over 5 references to the SAME iterator groups the stream into
        # consecutive batches of 5 (any trailing partial batch is dropped).
        # NOTE(review): batch size 5 is hard-coded independently of
        # max_batch_size (default 8) — confirm this is intentional.
        benchmark_by_5 = zip(*(benchmark_stream,) * 5)

        def gen_dialog_tuples():
            # Cap the run at the first 100 batches of 5.
            for _, batch in zip(range(100), benchmark_by_5):
                yield list(batch)

        for dtuple in gen_dialog_tuples():
            # Element 4 of each benchmark tuple is the dialog to send.
            dialogs = [z[4] for z in dtuple]
            results = generator.chat_completion(
                dialogs,
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )
            for tpl, result in zip(dtuple, results):
                if tuple:
                    # Machine-readable output: echo the benchmark tuple plus
                    # the model's answer (JSON-escaped) as one tuple literal.
                    t0 = tpl[0]
                    t1 = tpl[1]
                    t2 = tpl[2]
                    t3 = tpl[3]
                    t4 = tpl[4]
                    t5 = json.dumps(result['generation']['content'])
                    print(f'("{t0}","{t1}","{t2}","{t3}",{t4},"{t5}")')
                else:
                    # BUG FIX: the original referenced an undefined name `msg`
                    # here (NameError). Iterate the prompt messages of this
                    # dialog, then print the model's reply.
                    for msg in tpl[4]:
                        print(f"{msg['role'].capitalize()}: {msg['content']}\n")
                    print(
                        f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
                    )
                    print("\n==================================\n")
|
|
|
# Expose main() as a CLI via python-fire: each parameter becomes a flag
# (e.g. --ckpt_dir, --tokenizer_path, --benchmark_path, --tuple).
if __name__ == "__main__":
    fire.Fire(main)
|
|