import argparse
import json
import os

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
)


def chunk_and_classify(text, classifier, tokenizer, max_len=512, stride=50):
    """
    Splits a given text into overlapping chunks, classifies each chunk using a
    provided classifier, and computes the average classification scores for
    each label across all chunks.

    Args:
        text (str): The input text to be chunked and classified.
        classifier (Callable): A function or model that takes a text input and
            returns a list of dictionaries containing classification labels and scores.
        tokenizer (Callable): A tokenizer function or model that tokenizes the input
            text and provides token IDs.
        max_len (int, optional): The maximum length of each chunk in tokens. Defaults to 512.
        stride (int, optional): The number of tokens to overlap between consecutive chunks.
            Defaults to 50.

    Returns:
        dict: A dictionary where keys are classification labels and values are the
            average scores for each label across all chunks.
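
    Example (illustrative; assumes `classifier` and `tokenizer` are built as in
    main() below, and the label names depend on the trained model's config):

        scores = chunk_and_classify(long_text, classifier, tokenizer)
        # e.g. {"LABEL_0": 0.12, "LABEL_1": 0.88}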
| """ |

    # Tokenize without special tokens; the pipeline re-adds them per chunk.
    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
    # Leave room for the special tokens that get re-added when each decoded
    # chunk is re-encoded, so no chunk exceeds the model's max length.
    chunk_len = max_len - tokenizer.num_special_tokens_to_add()
    chunks = []
    for i in range(0, tokens.size(0), chunk_len - stride):
        chunk_ids = tokens[i : i + chunk_len]
        chunks.append(tokenizer.decode(chunk_ids, skip_special_tokens=True))
        if i + chunk_len >= tokens.size(0):
            break

    # Classify each chunk; truncation=True is a safety net in case the
    # decode/re-encode round trip drifts past the model's max length.
    chunk_scores = []
    for chunk in chunks:
        scores = classifier(chunk, truncation=True)[0]
        chunk_scores.append({d["label"]: d["score"] for d in scores})

    # Average each label's score across all chunks.
    avg_scores = {
        label: sum(s[label] for s in chunk_scores) / len(chunk_scores)
        for label in chunk_scores[0]
    }
    return avg_scores


def main():
    """Parse CLI arguments, load the trained model, and classify the input text."""
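    # Example invocations (script name and paths here are illustrative):
    #   python classify.py --text "Some document text"
    #   python classify.py --input_file paper.txt --model_dir ./checkpoint-3486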
    default_dir = "~/Code/Huggingface-metadata-project/BERTley/checkpoint-3486"
    parser = argparse.ArgumentParser(
        description="Run inference on a trained BERT metadata classifier"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=default_dir,
        help="Directory where your trained model and config live",
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--text", type=str, help="Raw text string to classify")
    group.add_argument(
        "--input_file",
        type=str,
        help="Path to a .txt file containing the document to classify",
    )
    args = parser.parse_args()

    # Expand "~" explicitly; from_pretrained does not expand it for local paths.
    model_dir = os.path.expanduser(args.model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        top_k=None,  # return scores for all labels (replaces the deprecated return_all_scores=True)
    )
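    # With top_k=None, a single string input yields a list containing one list
    # of per-label dicts, e.g. [[{"label": "LABEL_0", "score": 0.9}, ...]];
    # label names here are illustrative (real ones come from the model's id2label).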

    if args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        text = args.text

    # Route short inputs straight to the pipeline; chunk anything longer than
    # BERT's 512-token limit (the count below includes [CLS]/[SEP]).
    tokens = tokenizer(text, return_tensors="pt")["input_ids"]
    if tokens.size(1) <= 512:
        result = classifier(text)[0]
        scores = {d["label"]: d["score"] for d in result}
    else:
        scores = chunk_and_classify(text, classifier, tokenizer)

    print(json.dumps(scores, indent=2))


if __name__ == "__main__":
    main()