| import streamlit as st |
| import torch |
| import transformers |
| from trl import AutoModelForCausalLMWithValueHead |
| import math |
| import time |
|
|
|
|
| st.set_page_config(page_title="RLHF Magic | Movie Reviews", page_icon="🍿", layout="wide") |
|
|
|
|
| st.markdown(""" |
| <style> |
| .big-font { font-size:22px !important; font-weight: 500; } |
| .stProgress .st-bo { transition: background-color 0.5s ease; } |
| </style> |
| """, unsafe_allow_html=True) |
|
|
| st.title("🍿 Нейросеть-Кинокритик: До и После RLHF") |
| st.markdown(""" |
| <div class="big-font"> |
| Посмотрите, как работает магия обучения с подкреплением (RLHF). <br> |
| Слева — базовая модель GPT-2, которая пишет что вздумается. Справа — та же модель, но <b>натренированная всегда писать позитивные отзывы</b>, даже если вы начинаете текст с ужасных слов! |
| </div> |
| <br> |
| """, unsafe_allow_html=True) |
|
|
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
| @st.cache_resource |
| def load_models(): |
| reward_path = "reward_model_trained" |
| ppo_path = "ppo_model_trained" |
| orig_model_name = "lvwerra/gpt2-imdb" |
| |
| |
| reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path) |
| reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(reward_path).to(DEVICE).eval() |
| |
| |
| orig_tokenizer = transformers.AutoTokenizer.from_pretrained(orig_model_name) |
| if orig_tokenizer.pad_token is None: |
| orig_tokenizer.pad_token = orig_tokenizer.eos_token |
| orig_model = transformers.AutoModelForCausalLM.from_pretrained(orig_model_name).to(DEVICE).eval() |
| |
| |
| rlhf_model_full = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_path).to(DEVICE).eval() |
| rlhf_model = rlhf_model_full.pretrained_model |
| |
| return reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model |
|
|
| with st.spinner("⏳ Подготовка нейросетей... (занимает около минуты при первом старте)"): |
| reward_model, reward_tokenizer, orig_model, orig_tokenizer, rlhf_model = load_models() |
|
|
|
|
| def compute_reward(text): |
| inputs = reward_tokenizer(text, truncation=True, max_length=512, padding=True, return_tensors="pt").to(DEVICE) |
| with torch.no_grad(): |
| score = reward_model(**inputs).logits[0, 0].item() |
| return score |
|
|
| def get_positivity_percent(score): |
| return int((1 / (1 + math.exp(-score))) * 100) |
|
|
| def generate_text(model, tokenizer, prompt, max_new_tokens, temperature, top_p): |
| inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) |
| with torch.no_grad(): |
| outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, |
| temperature=temperature, top_p=top_p, pad_token_id=tokenizer.eos_token_id) |
| return tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
| def stream_text(text, delay=0.03): |
| for word in text.split(" "): |
| yield word + " " |
| time.sleep(delay) |
|
|
|
|
| st.sidebar.image("https://huggingface.co/front/assets/huggingface_logo-noborder.svg", width=50) |
| st.sidebar.header("🎛 Настройки генерации") |
| max_tokens = st.sidebar.slider("Длина продолжения (токенов)", 20, 150, 70) |
| temp = st.sidebar.slider("Креативность (Temperature)", 0.1, 1.5, 0.8) |
| st.sidebar.info("💡 **Попробуйте начать так:**\n\n- *I hate this movie because*\n- *The acting was terrible and*\n- *To be honest, the plot was*") |
|
|
| |
| user_prompt = st.text_input("✍️ Напишите начало отзыва (на англ.) и нажмите Enter:", |
| value="The director tried to make a good movie and", |
| max_chars=100) |
|
|
| if st.button("Мне повезет!", type="primary", use_container_width=True): |
| |
| |
| with st.spinner("GPT goes brrr..."): |
| orig_text = generate_text(orig_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95) |
| orig_reward = compute_reward(orig_text) |
| orig_percent = get_positivity_percent(orig_reward) |
| |
| rlhf_text = generate_text(rlhf_model, orig_tokenizer, user_prompt, max_tokens, temp, 0.95) |
| rlhf_reward = compute_reward(rlhf_text) |
| rlhf_percent = get_positivity_percent(rlhf_reward) |
|
|
| st.markdown("---") |
| |
| |
| col1, col2 = st.columns(2) |
| |
| |
| with col1: |
| with st.container(border=True): |
| st.subheader("До RLHF (Свободная GPT-2)") |
| st.caption("Пишет как попало (может быть негативной)") |
| |
| |
| st.progress(orig_percent / 100, text=f"Уровень позитивности: {orig_percent}%") |
| |
| |
| st.write_stream(stream_text(orig_text)) |
|
|
| |
| with col2: |
| with st.container(border=True): |
| st.subheader("После RLHF (Good Boy Model)") |
| st.caption("Старается вырулить любой текст в позитив") |
| |
| |
| st.progress(rlhf_percent / 100, text=f"Уровень позитивности: {rlhf_percent}%") |
| |
| |
| time.sleep(1) |
| st.write_stream(stream_text(rlhf_text, delay=0.04)) |
| |
| |
| if rlhf_percent > orig_percent + 20 and rlhf_percent > 70: |
| st.balloons() |
| st.toast('🎉 RLHF модель блестяще спасла ситуацию!', icon='😍') |
| elif rlhf_percent < 50: |
| st.toast('Начало было настолько суровым, что даже RLHF сдалась.', icon='💀') |
|
|