| import streamlit as st |
| import text_transformation_tools as ttt |
| from transformers import pipeline |
| import plotly.express as px |
|
|
|
|
def read_pdf(file):
    """Extract the text of an uploaded PDF file.

    Args:
        file: the uploaded PDF file object (as returned by st.file_uploader).

    Returns:
        The extracted text as returned by ttt.pdf_to_text — per the caller
        at the bottom of this file, a list of paragraph strings.
    """
    # Bug fix: the original read the module-level global `uploaded_file`
    # instead of the `file` argument, silently coupling this helper to the
    # script state and breaking any other caller.
    return ttt.pdf_to_text(file)
|
|
def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Classify text units against `topics` with a zero-shot model.

    Args:
        paragraphs: list of paragraph strings extracted from the PDF.
        topics: list of topic labels to classify against.
        model: Hugging Face model identifier for the zero-shot pipeline.
        mode: 'paragraphs' to classify whole paragraphs, or 'sentences' to
            split each paragraph on '.' and classify sentence by sentence.
        min_chars: minimum length of a cleansed unit for it to be analyzed.
        prob: minimum classification score (0-1) for a unit to count as
            relevant to a topic.

    Returns:
        dict mapping each topic to the list of original (uncleansed) text
        units classified as relevant to it.

    Raises:
        ValueError: if `mode` is neither 'paragraphs' nor 'sentences'.
    """
    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    relevant_parts = {topic: [] for topic in topics}

    if mode == 'paragraphs':
        text = paragraphs
    elif mode == 'sentences':
        # Naive sentence split on '.'; drop the empty fragments produced by
        # trailing periods so they do not inflate the progress total.
        text = [sentence
                for paragraph in paragraphs
                for sentence in paragraph.split('.')
                if sentence.strip()]
    else:
        # Bug fix: the original left `text` undefined for an unknown mode,
        # which surfaced later as a confusing NameError.
        raise ValueError(
            "mode must be 'paragraphs' or 'sentences', got {!r}".format(mode))

    min_score = prob

    with st.spinner('Analyzing text...'):
        counter = 0
        counter_rel = 0
        counter_tot = len(text)

        # st.empty() gives a single placeholder so the progress line below
        # overwrites itself instead of stacking one message per unit.
        with st.empty():
            for sequence_to_classify in text:
                # Collapse newlines and runs of whitespace to single spaces;
                # the original's .replace(' ', ' ') was a no-op.
                cleansed_sequence = ' '.join(sequence_to_classify.split())

                if len(cleansed_sequence) >= min_chars:
                    classified = classifier(cleansed_sequence, topics,
                                            multi_label=True)

                    # labels/scores are parallel lists sorted by score.
                    for label, score in zip(classified['labels'],
                                            classified['scores']):
                        if score >= min_score:
                            relevant_parts[label].append(sequence_to_classify)
                            counter_rel += 1

                counter += 1
                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(
                    counter, counter_tot, mode, counter_rel, mode))

    return relevant_parts
|
|
|
|
# Mapping from Hugging Face model identifier to the human-readable label
# shown in the model select box (speed/language hints are rough guidance).
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    # Bug fix: CamemBERT is a French-language model; it was mislabeled
    # "(fast, english)" in the original.
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)',
}


def format_func(option):
    """Return the human-readable label for model identifier `option`."""
    return CHOICES[option]
|
|
# --- File upload and topic selection -------------------------------------
st.header('File and topics')
uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
# Typo fix: "coma separated" -> "comma separated".
topics = st.text_input(label='Enter comma separated sustainability topics of interest.', value='human rights, sustainability')


# --- Analysis settings ----------------------------------------------------
st.header('Settings')
col1, col2 = st.columns(2)


with col1:
    model = st.selectbox("Select model used to analyze pdf.", options=list(CHOICES.keys()), format_func=format_func, index=3)
    # Typo fix: "Chose" -> "Choose".
    mode = st.selectbox(label='Choose if you want to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
with col2:
    min_chars = st.number_input(label='Minimum number of characters to analyze in a text', min_value=0, max_value=500, value=20)
    # The widget takes a percentage; convert once here to a 0-1 probability.
    probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90) / 100


# Normalize the raw comma-separated input into a clean list of topic labels.
topics = [topic.strip() for topic in topics.split(',')]
|
|
# --- Run the analysis and show results ------------------------------------
st.header('Analyze PDF')


if st.button('Analyze PDF'):
    # Robustness fix: the original crashed with an opaque error when the
    # button was pressed before any file was uploaded.
    if uploaded_file is None:
        st.error('Please upload a .pdf file first.')
    else:
        with st.spinner('Reading PDF...'):
            text = read_pdf(uploaded_file)
            page_count = ttt.count_pages(uploaded_file)
            # detect_language returns a sequence; the first entry is the
            # detected language code/name — TODO confirm against ttt.
            language = ttt.detect_language(' '.join(text))[0]
            st.subheader('Overview')
            st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))

        st.subheader('Analysis')
        relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)

        # Number of relevant units found per topic, in topic order.
        counts = [len(relevant_parts[topic]) for topic in topics]

        # Bug fix: `mode` is already plural ('paragraphs'/'sentences'), so
        # the original '{}s' rendered e.g. "paragraphss".
        fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))

        st.plotly_chart(fig)

        st.subheader('Relevant Passages')
        st.write(relevant_parts)
|
|
|
|