| |
| |
|
|
| import requests, json |
| from collections import namedtuple |
| from functools import lru_cache |
| from typing import List |
| from dataclasses import dataclass, field |
| from datetime import datetime as dt |
| import streamlit as st |
|
|
| from codetiming import Timer |
| from transformers import AutoTokenizer |
|
|
| from source import Source, Summary |
| from scrape_sources import stub as stb |
|
|
|
|
|
|
@dataclass
class Digestor:
    """Builds a news digest: retrieves user-chosen articles, chunks and
    summarizes them via the Hugging Face inference API, and collects the
    results as Summary objects."""

    timer: Timer                     # codetiming Timer whose `.timers` registry records per-article stage timings
    cache: bool = True               # NOTE(review): forwarded to perform_summarization but never read there — confirm
    text: str = field(default="no_digest")       # final digest text, filled in by build_digest()
    stubs: List = field(default_factory=list)    # article stubs (scrape_sources.stub) or pre-built Summary objects

    # Cluster names the user selected; drives relevance() sorting.
    user_choices: List = field(default_factory=list)

    # Summary objects produced by digest().
    summaries: List = field(default_factory=list)

    # Record type for per-run digest metadata (annotation builds the
    # namedtuple type inline); defaults to None until populated.
    digest_meta: namedtuple(
        "digestMeta",
        [
            'digest_time',
            'number_articles',
            'digest_length',
            'articles_per_cluster'
        ]) = None

    token_limit: int = 1024   # max tokens allowed per summarization chunk
    word_limit: int = 400     # words per chunk before the tokenized-length check

    # Parameters sent with every summarization request.
    # NOTE(review): "use_cache" is bound to the class-level default of `cache`
    # (True) at class-definition time, NOT to the instance's value — confirm intended.
    SUMMARIZATION_PARAMETERS = {
        "do_sample": False,
        "use_cache": cache,
    }

    # Inference endpoint and auth header (API token read from Streamlit secrets).
    API_URL = "https://api-inference.huggingface.co/models/sshleifer/distilbart-cnn-12-6"
    headers = {"Authorization": f"""Bearer {st.secrets['ato']}"""}
| |
| |
| |
| |
| |
| |
| |
| |
| |
| def relevance(self, summary): |
| return len(set(self.user_choices) & set(summary.cluster_list)) |
|
|
| def digest(self): |
| """Retrieves all data for user-chosen articles, builds summary object list""" |
| |
| self.timer.timers.clear() |
| |
| with Timer(name=f"digest_time", text="Total digest time: {seconds:.4f} seconds"): |
| |
| |
| for stub in self.stubs: |
| |
| |
| if not isinstance(stub, stb): |
| self.summaries.append(stub) |
| else: |
| |
| summary_data: List |
| |
| text, summary_data = stub.source.retrieve_article(stub) |
| |
| |
| if text != None and summary_data != None: |
| |
| with Timer(name=f"{stub.hed}_chunk_time", logger=None): |
| chunk_list = self.chunk_piece(text, self.word_limit, stub.source.source_summarization_checkpoint) |
| |
| with Timer(name=f"{stub.hed}_summary_time", text="Whole article summarization time: {:.4f} seconds"): |
| summary = self.perform_summarization( |
| stub.hed, |
| chunk_list, |
| self.API_URL, |
| self.headers, |
| cache = self.cache, |
| ) |
| |
| |
| |
|
|
| self.summaries.append( |
| Summary( |
| source=summary_data[0], |
| cluster_list=summary_data[1], |
| link_ext=summary_data[2], |
| hed=summary_data[3], |
| dek=summary_data[4], |
| date=summary_data[5], |
| authors=summary_data[6], |
| original_length = summary_data[7], |
| summary_text=summary, |
| summary_length=len(' '.join(summary).split(' ')), |
| chunk_time=self.timer.timers[f'{stub.hed}_chunk_time'], |
| query_time=self.timer.timers[f"{stub.hed}_query_time"], |
| mean_query_time=self.timer.timers.mean(f'{stub.hed}_query_time'), |
| summary_time=self.timer.timers[f'{stub.hed}_summary_time'], |
| |
| ) |
| ) |
| else: |
| print("Null article") |
|
|
|
|
| |
| self.summaries.sort(key=self.relevance, reverse=True) |
|
|
| |
| def query(self, payload, API_URL, headers): |
| """Performs summarization inference API call.""" |
| data = json.dumps(payload) |
| response = requests.request("POST", API_URL, headers=headers, data=data) |
| return json.loads(response.content.decode("utf-8")) |
|
|
|
|
    def chunk_piece(self, piece, limit, tokenizer_checkpoint, include_tail=False):
        """Breaks articles into chunks that will fit the desired token length limit"""
        # Total whitespace-delimited word count of the article.
        words = len(piece.split(' '))

        # Word offsets at multiples of `limit`: [0, limit, 2*limit, ...].
        base_range = [i*limit for i in range(words//limit+1)]

        # Append the final partial chunk when requested, or when the whole
        # article is shorter than one chunk (base_range == [0]).
        if include_tail or base_range == [0]:
            base_range.append(base_range[-1]+words%limit)

        # Consecutive (start, end) word-index pairs for slicing.
        range_list = [i for i in zip(base_range,base_range[1:])]

        fractured = piece.split(' ')
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
        chunk_list = []

        for i, j in range_list:
            # NOTE(review): len(tokenizer(chunk)) measures the length of the
            # returned encoding object (for a BatchEncoding that is its number
            # of fields, e.g. input_ids/attention_mask), not the token count —
            # confirm whether len(tokenizer(chunk)['input_ids']) was intended.
            if (tokenized_len := len(tokenizer(chunk := ' '.join(fractured[i:j])))) <= self.token_limit:
                chunk_list.append(chunk)
            else:
                # Over the limit: the negative slice bound trims roughly the
                # excess (tokenized_len - token_limit) words off the end.
                chunk_list.append(' '.join(chunk.split(' ')[: self.token_limit - tokenized_len ]))
        # Normalize spaced-out periods (" . ") back to ". ".
        chunk_list = [i.replace(' . ','. ') for i in chunk_list]
        return chunk_list
|
|
|
|
|
|
| |
    def perform_summarization(self, stubhead, chunklist : List[str], API_URL: str, headers: dict, cache=True) -> List[str]:
        """For each in chunk_list, appends result of query(chunk) to list collection_bin."""
        # NOTE(review): `cache` is accepted but never used here; the flag the
        # API actually sees is whatever is baked into SUMMARIZATION_PARAMETERS.
        collection_bin = []
        # NOTE(review): the retry counter is shared across ALL chunks, so four
        # failures in total exhaust retries for the rest of the article —
        # confirm this total-budget behavior is intended.
        repeat = 0

        for chunk in chunklist:
            safe = False
            summarized_chunk = None
            # Accumulates per-query time under the "<hed>_query_time" timer name.
            with Timer(name=f"{stubhead}_query_time", logger=None):
                while not safe and repeat < 4:
                    try:
                        summarized_chunk = self.query(
                            {
                                "inputs": str(chunk),
                                "parameters": self.SUMMARIZATION_PARAMETERS
                            },
                            API_URL,
                            headers,
                        )[0]['summary_text']
                        safe = True
                    except Exception as e:
                        # API hiccup (rate limit, cold model, malformed reply): log and retry.
                        print("Summarization error, repeating...")
                        print(e)
                        repeat+=1
            print(summarized_chunk)
            # Drop chunks that never summarized successfully.
            if summarized_chunk is not None:
                collection_bin.append(summarized_chunk)
        return collection_bin
| |
|
|
|
|
| |
| def build_digest(self) -> str: |
| """Called to show the digest. Also creates data dict for digest and summaries.""" |
| |
| |
| |
| |
| |
| digest = [] |
| for each in self.summaries: |
| digest.append(' '.join(each.summary_text)) |
| |
| self.text = '\n\n'.join(digest) |
|
|
| |
| out_data = {} |
| t = dt.now() |
| datetime_str = f"""{t.hour:.2f}:{t.minute:.2f}:{t.second:.2f}""" |
| choices_str = ', '.join(self.user_choices) |
| digest_str = '\n\t'.join(digest) |
| |
| |
| |
| |
| |
| |
| |
| summaries = { |
| |
| c: { |
| |
| k._fields[i]:p if k._fields[i]!='source' |
| else |
| { |
| 'name': k.source.source_name, |
| 'source_url': k.source.source_url, |
| 'Summarization" Checkpoint': k.source.source_summarization_checkpoint, |
| 'NER Checkpoint': k.source.source_ner_checkpoint, |
| } for i,p in enumerate(k) |
| } for c,k in enumerate(self.summaries)} |
|
|
| out_data['timestamp'] = datetime_str |
| out_data['article_count'] = len(self.summaries) |
| out_data['digest_length'] = len(digest_str.split(" ")) |
| out_data['sum_params'] = { |
| 'token_limit':self.token_limit, |
| 'word_limit':self.word_limit, |
| 'params':self.SUMMARIZATION_PARAMETERS, |
| } |
| out_data['summaries'] = summaries |
|
|
| return out_data |