| import streamlit as st |
| import torch |
| from normflows import nflow |
| import numpy as np |
| import seaborn as sns |
| import pandas as pd |
|
|
# Dataset upload and training hyper-parameters shown above the token form.
uploaded_file = st.file_uploader("Choose original dataset")
col1, col2, col3 = st.columns(3)
bw = col1.number_input('Scale', value=3.05)
# Streamlit formats float inputs as "%0.2f" with step 0.01 by default, which would
# display the 0.0002 default as "0.00"; use a finer format/step so it is visible and editable.
wd = col2.number_input('Weight Decay', value=0.0002, min_value=0.0, step=0.0001, format='%.6f')
# min_value=1: the training progress bar divides by this value.
iters = col3.number_input('Iterations', value=400, min_value=1, step=1)
|
|
|
|
|
|
def _train_with_progress(api):
    """Run ``api.train`` for the module-level ``iters`` while updating a Streamlit progress bar."""
    total = max(int(iters), 1)  # guard: iters == 0 would otherwise raise ZeroDivisionError
    bar = st.progress(0)
    for step in api.train(iters=iters):
        # step[0] is assumed to be the current iteration number — TODO confirm against
        # nflow.train. Clamp because st.progress rejects floats outside [0, 1].
        bar.progress(min(max(float(step[0]) / total, 0.0), 1.0))
    bar.progress(1.0)


def _plot_model_samples(api):
    """Plot a KDE of model samples for the scaled training data, overlaid with the real points."""
    conditioning = torch.as_tensor(api.scaled, dtype=torch.float32)
    # Draw ONE batch and compute the rows to drop from that same batch. The previous
    # code called api.model.sample twice, so (if sampling is stochastic) the argmin
    # indices came from a different draw than the array being trimmed, and the model
    # was evaluated twice.
    drawn = api.model.sample(conditioning).detach().cpu().numpy()
    # np.argmin(axis=0) yields one row index per column; drop those rows.
    drawn = np.delete(drawn, np.argmin(drawn, axis=0), axis=0)

    grid = sns.jointplot(
        x=drawn[:, 0],
        y=drawn[:, 1],
        kind='kde',
        cmap=sns.color_palette("Blues", as_cmap=True),
        fill=True,
        label='Gaussian KDE',
        levels=1000,
    )
    scatter = sns.scatterplot(
        x=api.scaled[:, 0],
        y=api.scaled[:, 1],
        ax=grid.ax_joint,
        c='orange',
        marker='+',
        s=100,
        label='Real',
    )
    st.pyplot(scatter.get_figure())


def compute(dim, *, n_samples=1000):
    """Train a flow on the uploaded dataset, plot its samples, and generate new rows.

    Reads the module-level widget values ``uploaded_file``, ``bw``, ``wd`` and ``iters``.

    Args:
        dim: value passed as ``dim`` to ``nflow`` (the caller derives it from the
            uploaded CSV's header).
        n_samples: number of synthetic rows to generate (default 1000, as before).

    Returns:
        The generated samples passed through ``api.scaler.inverse_transform``.
    """
    api = nflow(dim=dim, latent=16, dataset=uploaded_file)
    api.compile(optim=torch.optim.ASGD, bw=bw, lr=0.0001, wd=wd)

    _train_with_progress(api)
    _plot_model_samples(api)

    # Push standard-normal latent draws (mean 0, std 1) through the model.
    latent = torch.randn(n_samples, api.scaled.shape[-1], dtype=torch.float32)
    generated = api.model.sample(latent).detach().cpu().numpy()
    return api.scaler.inverse_transform(generated)
|
|
# Token form: generation only runs after the user submits a valid access token.
with st.form('login_form'):
    st.write('Token for generation:')
    # Mask the access token so it is not displayed on screen.
    token = st.text_input('Token', type='password')
    submit = st.form_submit_button('Submit')
|
|
import csv  # local import keeps this edit self-contained (CSV header parsing below)

# Only react after the form is submitted; previously the "no access" message was
# shown on the very first page load, before any token had been entered.
if submit:
    # Reject empty tokens explicitly: if the secret is a single string, `'' in secret`
    # would otherwise be True.
    if token and token in st.secrets['tokens']:
        if uploaded_file is None:
            st.write('Upload your file')
        else:
            # Only the header line is needed; split on raw bytes (0x0A never occurs
            # inside a multi-byte UTF-8 sequence) instead of decoding the whole file.
            header = uploaded_file.getvalue().split(b'\n', 1)[0].decode('utf-8')
            # csv.reader respects quoted commas in column names.
            # One column is excluded — presumably an index column; TODO confirm
            # against how nflow loads the dataset.
            dims = len(next(csv.reader([header]), [])) - 1
            if dims < 1:
                st.error('The uploaded CSV needs a header row with at least two columns.')
            else:
                samples = compute(dims)
                # The third positional parameter of st.download_button is file_name,
                # so pass the file name and MIME type by keyword.
                st.download_button(
                    'Download generated CSV',
                    pd.DataFrame(samples).to_csv(),
                    file_name='generated.csv',
                    mime='text/csv',
                )
    else:
        st.markdown('## :red[You dont have access]')
        # NOTE(review): link text "@advprop" and URL host "adprop" differ — confirm the intended handle.
        st.markdown('Buy tokens here: [@advprop](https://adprop.t.me)')