| from tclogger import logger |
| from transformers import AutoTokenizer |
|
|
| from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED |
|
|
|
|
class TokenChecker:
    """Count the tokens of a prompt and compare them with a model's limit.

    The constructor resolves the requested model key (falling back to
    ``DEFAULT_MODEL`` when the key is not in ``MODEL_MAP``) and loads the
    matching tokenizer through ``AutoTokenizer.from_pretrained``.
    """

    # Model key used when the requested one is not present in MODEL_MAP.
    DEFAULT_MODEL = "nous-mixtral-8x7b"

    # Model keys whose tokenizer is loaded from a different repository than
    # the one listed in MODEL_MAP.
    # NOTE(review): presumably because the original repos are gated and
    # cannot be downloaded anonymously -- confirm.
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Store the prompt, resolve the model key and load its tokenizer.

        Args:
            input_str: Prompt text whose tokens will be counted.
            model: Model key; replaced by ``DEFAULT_MODEL`` when it is not a
                key of ``MODEL_MAP``.
        """
        self.input_str = input_str
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]

        # Prefer the substitute repository when one is registered for this
        # model key; otherwise load from the model's own repository.
        tokenizer_repo = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

    def count_tokens(self) -> int:
        """Tokenize the stored prompt, log the token count and return it."""
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self) -> int:
        """Return the ``TOKEN_LIMIT_MAP`` entry for the resolved model key."""
        return TOKEN_LIMIT_MAP[self.model]

    def get_token_redundancy(self) -> int:
        """Return the tokens left after the prompt and the reserved budget.

        Computed as ``limit - TOKEN_RESERVED - prompt_tokens``; a result of
        zero or less means the prompt does not fit.
        """
        return int(self.get_token_limit() - TOKEN_RESERVED - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Return ``True`` when the prompt fits, otherwise raise.

        Raises:
            ValueError: If the prompt plus ``TOKEN_RESERVED`` leaves no
                spare tokens within the model's limit.
        """
        token_limit = self.get_token_limit()
        # Tokenize exactly once: the count is needed both for the check and
        # for the error message, and every call re-encodes and logs.
        token_count = self.count_tokens()
        redundancy = int(token_limit - TOKEN_RESERVED - token_count)
        if redundancy <= 0:
            # Report the inequality that was actually evaluated, including
            # the reserved budget, so the message is never self-contradictory.
            raise ValueError(
                f"Prompt exceeded token limit: "
                f"{token_count} + {TOKEN_RESERVED} reserved >= {token_limit}"
            )
        return True
|
|