# Imports, grouped per PEP 8: stdlib, third-party, then local modules.
# NOTE(review): the star imports from `pred` and `load_data` are what bring
# Encoder/Decoder/Seq2Seq/genOp/get_sample into scope for main(); they are
# kept last so project names shadow nothing above them.
from pathlib import Path
from typing import Any, Callable, Dict, List

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import GPT2Tokenizer

from pred import *
from load_data import *
|
|
def main():
    """Streamlit app: generate football commentary from a trained Seq2Seq model.

    Loads a GPT-2 tokenizer and a checkpointed encoder/decoder, then renders a
    sidebar for decoding configuration and a main panel that samples from the
    test set and generates a commentary sequence.
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_bos_token=True)
    # GPT-2 ships with no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Encoder/Decoder/Seq2Seq come from the `pred` star import. The decoder
    # output size matches the GPT-2 vocabulary (50257).
    encoder = Encoder(h=64, n=2, e=64, a=4, o=64).to(device)
    decoder = Decoder(h=64, n=2, e=64, a=4, o=50257).to(device)
    model = Seq2Seq(encoder, decoder).to(device)

    checkpoint = torch.load('./seq2seq_checkpoint.pt', weights_only=True, map_location=device)
    # Loading into `model` also populates `encoder`/`decoder` since they are
    # its submodules; genOp below uses those two directly.
    model.load_state_dict(checkpoint['model_state_dict'])

    st.title("Footy Commentary Generator")

    st.sidebar.header("Configuration")
    # NOTE(review): "Custom Input" is offered here but never handled below —
    # selecting it currently renders nothing. Confirm whether that branch was
    # meant to be implemented.
    tab_selection = st.sidebar.radio(
        "Select Input Method:",
        ["Random Sample from Test Set", "Custom Input"]
    )

    st.sidebar.header("Decoding Configuration")
    st.session_state.decoding_mode = st.sidebar.selectbox(
        "Decoding Mode",
        ["greedy", "sample", "top-k", "diverse-beam-search", "min-bayes-risk"]
    )

    # Parameters common to every decoding mode, plus mode-specific extras.
    # The dict is passed straight through to genOp as **kwargs.
    st.session_state.decoding_params = {}
    st.session_state.decoding_params['max_len'] = st.sidebar.slider('Max length', 1, 500, 50)
    st.session_state.decoding_params['temperature'] = st.sidebar.slider('Temperature', 0.0, 1.0, 0.1)
    if st.session_state.decoding_mode == "top-k":
        st.session_state.decoding_params["k"] = st.sidebar.slider("k value", 1, 100, 5)
    elif st.session_state.decoding_mode == "diverse-beam-search":
        st.session_state.decoding_params["beam_width"] = st.sidebar.slider("beam width", 1, 10, 1)
        st.session_state.decoding_params["diversity_penalty"] = st.sidebar.slider("diversity penalty", 0.0, 1.0, 0.1)
    elif st.session_state.decoding_mode == "min-bayes-risk":
        st.session_state.decoding_params["num_candidates"] = st.sidebar.slider("Number of candidates", 1, 30, 4)

    if tab_selection == "Random Sample from Test Set":
        st.header("Generate from Test Dataset")

        col1, col2 = st.columns([3, 1])

        with col1:
            st.write("Test dataset contains 5000 samples")

        with col2:
            if st.button("Generate Random Sample"):
                # BUG FIX: was randint(1, 5000), which could never pick
                # sample 0 — randint's upper bound is already exclusive,
                # so the valid index range 0..4999 is randint(0, 5000).
                random_idx = np.random.randint(0, 5000)
                st.session_state.random_idx = random_idx
                (st.session_state.ip,
                 st.session_state.ip_mask,
                 st.session_state.tg,
                 st.session_state.tg_mask) = get_sample(random_idx)

        # Membership test is the Streamlit-documented idiom for session_state
        # (clearer than hasattr, and survives across reruns the same way).
        if 'random_idx' in st.session_state:
            st.subheader(f"Sample #{st.session_state.random_idx}")
            # NOTE(review): `ip` is indexed with [0] but `tg` is not —
            # presumably ip is batched (1, seq) while tg is a flat sequence;
            # confirm against load_data.get_sample.
            st.session_state.x = tokenizer.decode(st.session_state.ip.tolist()[0])
            st.session_state.y = tokenizer.decode(st.session_state.tg.tolist())

            df = pd.DataFrame.from_dict({'X': [st.session_state.x], 'y': [st.session_state.y]})
            st.dataframe(df.T.reset_index(), width=800)

            if st.button("Generate Sequence"):
                with st.spinner("Generating sequence..."):
                    print(f'Ip: {st.session_state.ip} | Mask: {st.session_state.ip_mask} \n Mode: {st.session_state.decoding_mode} | Params: {st.session_state.decoding_params}')
                    st.session_state.tok_output = genOp(
                        encoder, decoder, device,
                        st.session_state.ip,
                        st.session_state.ip_mask,
                        mode=st.session_state.decoding_mode,
                        **st.session_state.decoding_params
                    )
                    print(f'\n\n\nOutput: {st.session_state.tok_output} \n')
                    st.session_state.output = tokenizer.decode(st.session_state.tok_output)

        if 'output' in st.session_state:
            st.subheader("Generated Sequence")
            st.write(st.session_state.output)
|
|
| if __name__ == "__main__": |
| main() |
| 1 |