| import gradio as gr |
| import jax |
| import jax.numpy as jnp |
| from jax import random |
| from jax.random import PRNGKey |
| import json |
| from globals import Char, State, UserInfo |
| from thompson import ( |
| init_thompson, |
| recommend_characters, |
| update_posterior, |
| compute_reward, |
| construct_feats, |
| ) |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
class LMCharacterKnowledge:
    """Ask a local causal language model about a fighting game's characters.

    The model is prompted for the game's roster and for per-character stats,
    both as JSON. Answers that parse and validate are memoised in
    ``self.cache``; when the model cannot be queried or its answer cannot be
    used, fixed fallback values are returned instead and nothing is cached.
    """

    # Component order of every character's archetype vector.
    _ARCHETYPE_ORDER = (
        "rushdown",
        "zoner",
        "grappler",
        "all_rounder",
        "setplay",
        "footsies",
    )
    # Scalar stats every character record must provide as numbers.
    _SCALAR_KEYS = ("difficulty", "execution_barrier", "neutral_intensity", "tier")
    # Roster used when the model's answer is unusable.
    _FALLBACK_ROSTER = ("Ryu", "Ken", "Luke")
    # Generation budget for the roster query.
    _ROSTER_MAX_TOKENS = 512

    def __init__(self, model_name: str, game_name: str):
        """Load the tokenizer and model for *model_name* and start an empty cache.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            game_name: Game every later query is about.
        """
        self.game_name = game_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.prompt = [
            {
                "role": "system",
                "content": "You are a knowledgeable bastion of fighting game knowledge. Your goal is to answer questions as best as possible about the game you are asked about.",
            }
        ]
        self.cache = {}

    def ask_lm(self, prompt, max_tok: int = 4096):
        """Send *prompt* to the model as a user message and return its reply.

        Args:
            prompt: Text of the user turn appended after the system prompt.
            max_tok: Maximum number of new tokens to generate.

        Returns:
            The decoded completion (prompt tokens stripped), or ``None`` when
            anything in the tokenize/generate/decode pipeline raises.
        """
        try:
            messages = self.prompt + [{"role": "user", "content": prompt}]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            # Honour the caller's budget instead of a hard-coded limit.
            outputs = self.model.generate(**inputs, max_new_tokens=max_tok)
            # generate() echoes the prompt tokens; keep only what follows them.
            prompt_len = inputs["input_ids"].shape[-1]
            result = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            )
            print(result)
            return result
        except Exception as e:
            # Best-effort boundary: callers fall back to defaults on None.
            print(f"Couldn't query{self.model}, error: {e}")
            return None

    @staticmethod
    def _extract_json(text: str, opener: str):
        """Return the first JSON value in *text* that starts with *opener*.

        Every occurrence of *opener* is tried in order, so prose before the
        payload (or a stray bracket in it) does not defeat the parse, and
        brackets inside JSON strings are handled by the real JSON parser.

        Raises:
            ValueError: If no occurrence of *opener* begins valid JSON.
        """
        decoder = json.JSONDecoder()
        start = text.find(opener)
        while start != -1:
            try:
                value, _ = decoder.raw_decode(text, start)
            except ValueError:
                start = text.find(opener, start + 1)
            else:
                return value
        raise ValueError(f"No JSON value starting with {opener!r} found")

    @staticmethod
    def _is_number(value) -> bool:
        """Tell whether *value* is an int or float (bools are rejected)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _clean_roster(parsed) -> list[str]:
        """Return the unique, non-blank string names in *parsed*, order kept.

        Anything that is not a list yields an empty roster.
        """
        if not isinstance(parsed, list):
            return []
        names = [
            item.strip()
            for item in parsed
            if isinstance(item, str) and item.strip()
        ]
        # Duplicates would make name -> index lookups ambiguous.
        return list(dict.fromkeys(names))

    def _validate_character_data(self, data) -> None:
        """Raise ``ValueError`` unless *data* has the expected stat layout."""
        required_keys = (*self._SCALAR_KEYS, "archetypes")
        if not isinstance(data, dict) or not all(
            key in data for key in required_keys
        ):
            raise ValueError("Missing required keys in parsed data")
        if not all(self._is_number(data[key]) for key in self._SCALAR_KEYS):
            raise ValueError("Non-numeric stat in parsed data")
        archetypes = data["archetypes"]
        if not isinstance(archetypes, dict) or not all(
            self._is_number(archetypes.get(name, 0.0))
            for name in self._ARCHETYPE_ORDER
        ):
            raise ValueError("Malformed archetypes in parsed data")

    @staticmethod
    def _default_character_data() -> dict:
        """Return a fresh copy of the neutral stats used when parsing fails."""
        return {
            "difficulty": 0.5,
            "execution_barrier": 0.5,
            "neutral_intensity": 0.5,
            "tier": 0.5,
            "archetypes": {
                "rushdown": 0.3,
                "zoner": 0.3,
                "grappler": 0.1,
                "all_rounder": 0.2,
                "setplay": 0.05,
                "footsies": 0.05,
            },
        }

    def get_roster(self) -> list[str]:
        """Return the game's character names, asking the model on first use.

        Returns:
            The cached roster if present; otherwise the names parsed from the
            model's JSON array (cached on success); otherwise the fallback
            roster ``["Ryu", "Ken", "Luke"]`` (not cached).
        """
        cache_key = f"roster_{self.game_name}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        roster_prompt = f"""
        List ALL playable characters in {self.game_name}. Return a structured json array of character names, nothing else at all.
        Example format is : ["Ryu", "Ken", "Chun Li", "Akuma"]
        """

        # ask_lm yields None when the query failed; treat that as "no text".
        response = (
            self.ask_lm(roster_prompt, max_tok=self._ROSTER_MAX_TOKENS) or ""
        )

        try:
            parsed = self._extract_json(response, "[")
        except (ValueError, RecursionError):
            parsed = None

        roster = self._clean_roster(parsed)
        if roster:
            self.cache[cache_key] = roster
            return roster

        # An empty roster is unusable downstream, so it also falls back.
        return list(self._FALLBACK_ROSTER)

    def get_character_data(self, char_name: str) -> dict:
        """Return the stat record for *char_name*, asking the model on first use.

        Returns:
            A dict with numeric ``difficulty``, ``execution_barrier``,
            ``neutral_intensity`` and ``tier`` entries plus an ``archetypes``
            dict. A validated model answer is cached; if the model fails or
            its answer is malformed, neutral defaults are returned uncached.
        """
        cache_key = f"char_{self.game_name}_{char_name}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        char_data_prompt = f"""
        for the character {char_name} in the game {self.game_name},
        provide some statistics in explicit JSON format:

        Example format:
        {{
            "difficulty": 0.7,
            "execution_barrier": 0.6,
            "neutral_intensity": 0.5,
            "tier": 0.8,
            "archetypes": {{
                "rushdown": 0.8,
                "zoner": 0.1,
                "grappler": 0.0,
                "all_rounder": 0.1,
                "setplay": 0.0,
                "footsies": 0.0
            }}
        }}

        Replace ALL values with actual numbers for {char_name}. Return ONLY the JSON object, nothing else.
        """

        # ask_lm yields None when the query failed; normalise so the parsing
        # and the diagnostics below can always treat the response as text.
        response = self.ask_lm(char_data_prompt, max_tok=300) or ""
        print(f"Raw response for {char_name}: {response}")

        try:
            data = self._extract_json(response, "{")
            print(f"Extracted JSON: {json.dumps(data)}")
            self._validate_character_data(data)
        except (ValueError, RecursionError) as e:
            print(f"Couldn't parse {char_name}'s data: {e}")
            print(f"Response was: {response[:200]}...")
            return self._default_character_data()

        self.cache[cache_key] = data
        return data

    def build_roster(self) -> "tuple[Char, list[str]]":
        """Fetch every character's stats and pack them into one batched ``Char``.

        Returns:
            ``(batched_chars, roster)`` where each field of ``batched_chars``
            is an array with one entry (or one archetype row) per name in
            ``roster``, in the same order.
        """
        roster = self.get_roster()
        records = [self.get_character_data(name) for name in roster]

        archetypes = jnp.array(
            [
                [rec["archetypes"].get(name, 0.0) for name in self._ARCHETYPE_ORDER]
                for rec in records
            ]
        )
        # Normalise each row to sum to 1; the epsilon avoids 0/0 when a
        # character's archetype weights are all zero.
        archetypes = archetypes / (
            jnp.sum(archetypes, axis=1, keepdims=True) + 1e-8
        )

        batched_chars = Char(
            difficulty=jnp.array([rec["difficulty"] for rec in records]),
            archetype_vec=archetypes,
            execution_level=jnp.array(
                [rec["execution_barrier"] for rec in records]
            ),
            neutral_required=jnp.array(
                [rec["neutral_intensity"] for rec in records]
            ),
            tier=jnp.array([rec["tier"] for rec in records]),
        )

        return batched_chars, roster
|
|
|
|
class FGRecommender:
    """Session state behind the recommender UI.

    Holds the character roster, the Thompson-sampling state, the user
    profile, the PRNG key and a log of recorded feedback. Everything except
    the key, ``n_archetypes`` and the history is ``None`` until
    :meth:`init_game` succeeds.
    """

    # Feature dimension passed to init_thompson.
    _FEATURE_DIM = 17

    def __init__(self):
        """Create an uninitialised session with a fixed PRNG seed."""
        self.lm = None
        self.chars = None
        self.roster = None
        self.state = None
        self.user = None
        self.key = PRNGKey(67)
        self.n_archetypes = 6
        self.history = []

    def init_game(self, game_name: str) -> str:
        """Load the roster for *game_name* and reset the bandit and user profile.

        Session attributes are only replaced once every step has succeeded,
        so a failure leaves any previously loaded game intact.

        Returns:
            A status message: a prompt for a name when *game_name* is blank,
            ``"loaded ..."`` on success, or ``"Error: ..."`` on failure.
        """
        if not game_name or not game_name.strip():
            return "please enter name of game"

        try:
            lm = LMCharacterKnowledge(
                model_name="LiquidAI/LFM2-350M", game_name=game_name
            )
            chars, roster = lm.build_roster()
            n_chars = len(roster)
            state = init_thompson(n_chars, self._FEATURE_DIM)
            user = UserInfo(
                skill_level=0.3,
                games_played=0,
                chars_attempted_mask=jnp.zeros(n_chars),
                wr=jnp.ones(n_chars) * 0.5,
                playtime=jnp.zeros(n_chars),
                pref_archetype=jnp.zeros(self.n_archetypes),
            )
        except Exception as e:
            # UI boundary: surface the failure as text instead of raising.
            return f"Error: {e}"

        # Commit atomically so state, user and roster always agree in size.
        self.lm, self.chars, self.roster = lm, chars, roster
        self.state, self.user = state, user
        return f"loaded {n_chars} from {game_name}"

    def get_recs(self, top_k: int = 5) -> "tuple[str, gr.Dropdown | None]":
        """Sample recommendations and describe them as Markdown.

        Args:
            top_k: Number of recommendations requested from
                ``recommend_characters``.

        Returns:
            ``(markdown, dropdown)``. Before a game is initialised this is
            ``("please init game", None)`` so that both wired outputs still
            receive a value; otherwise ``dropdown`` lists the recommended
            character names with the first one preselected.
        """
        if self.state is None:
            return "please init game", None

        self.key, subkey = random.split(self.key)

        sel, sample_rewards = recommend_characters(
            subkey,
            self.state,
            self.user,
            self.chars,
            len(self.roster),
            top_k=top_k,
            diversity_threshold=0.75,
        )

        # Negative entries are skipped — presumably padding for unfilled
        # slots; confirm against recommend_characters.
        picks = [
            (rank, int(idx))
            for rank, idx in enumerate(sel, start=1)
            if int(idx) >= 0
        ]

        parts = ["## Recommended Chars: \n\n"]
        for rank, char_idx in picks:
            tried = bool(self.user.chars_attempted_mask[char_idx] > 0.5)
            status = "TRIED" if tried else "NEW"
            reward = float(sample_rewards[char_idx])
            difficulty = float(self.chars.difficulty[char_idx])
            tier = float(self.chars.tier[char_idx])

            parts.append(f"### {rank}. {self.roster[char_idx]} {status} \n")
            parts.append(f"expected_reward: {reward: .4f} \n")
            parts.append(f"difficulty: {difficulty:.2f}\n")
            parts.append(f" Tier: {tier:.2f}\n\n")

        char_opts = [self.roster[char_idx] for _, char_idx in picks]

        return "".join(parts), gr.Dropdown(
            choices=char_opts, value=char_opts[0] if char_opts else None
        )

    def record_feedback(
        self, char_name: str, won: bool, rating: float, playtime: float
    ) -> str:
        """Fold one play session into the user profile and the history log.

        Args:
            char_name: Name of the character played; must be in the roster.
            won: Whether the session was won.
            rating: User rating forwarded to ``compute_reward``.
            playtime: Minutes played, added to the character's total.

        Returns:
            A status message describing what was recorded, or why nothing was.
        """
        if self.state is None or char_name is None:
            return "get recs first"

        try:
            char_idx = self.roster.index(char_name)
        except ValueError:
            return f"char {char_name} not found"

        sel_char_obj = jax.tree.map(lambda x: x[char_idx], self.chars)
        # NOTE(review): feats is built but never consumed, and
        # update_posterior (imported at file top) is never called, so
        # self.state is not changed by feedback. Presumably a posterior
        # update was intended here — confirm update_posterior's signature in
        # the thompson module before wiring it in.
        feats = construct_feats(self.user, sel_char_obj, char_idx)

        reward = float(
            compute_reward(
                won=won, completed=True, rating=rating, playtime_mins=playtime
            )
        )

        # Exponential moving average of the win rate (weight 0.2 on the
        # newest result).
        blended_wr = 0.8 * self.user.wr[char_idx] + 0.2 * float(won)
        self.user = self.user._replace(
            games_played=self.user.games_played + 1,
            chars_attempted_mask=self.user.chars_attempted_mask.at[char_idx].set(1),
            wr=self.user.wr.at[char_idx].set(blended_wr),
            playtime=self.user.playtime.at[char_idx].add(playtime),
        )

        self.history.append(
            {
                "character": char_name,
                "won": won,
                "rating": rating,
                "reward": reward,
            }
        )

        return f"recorded {char_name}'s feedback! Reward was {reward:.4f}"

    def get_stats(self) -> str:
        """Return a Markdown summary of the user profile.

        Includes up to five most-played characters once at least one
        character has been tried; returns a placeholder message before a
        game is initialised.
        """
        if self.user is None:
            return "no stats lol. play some games u scrub"

        tried = int(jnp.sum(self.user.chars_attempted_mask))
        total = len(self.roster)
        avg_wr = float(jnp.mean(self.user.wr))

        # Built line by line so no source indentation leaks into the Markdown.
        parts = [
            "## Your Stats\n",
            "\n",
            f"- **Games played:** {self.user.games_played}\n",
            f"- **Characters tried:** {tried}/{total}\n",
            f"- **Average win rate:** {avg_wr:.1%}\n",
            f"- **Skill level:** {self.user.skill_level:.2f}\n",
        ]

        if tried > 0:
            # A space after the hashes is required for a Markdown heading.
            parts.append("\n### Most Played:\n")
            # Negate so argsort yields the longest playtimes first.
            for idx in jnp.argsort(-self.user.playtime)[:5]:
                idx = int(idx)
                minutes = float(self.user.playtime[idx])
                if minutes > 0:
                    wr = float(self.user.wr[idx])
                    parts.append(
                        f"- **{self.roster[idx]}**: {minutes:.0f}m, {wr:.1%} WR\n"
                    )

        return "".join(parts)
|
|
| |
| # Single module-level recommender session; the Gradio callbacks wired in create_ui() read and mutate it. |
| app = FGRecommender() |
|
|
|
|
def create_ui():
    """Build the Gradio Blocks interface and wire it to the global ``app``.

    The left column holds game setup, the skill slider and the stats panel;
    the right column holds recommendations and the feedback form.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks(
        title="Fighting Game Character Recommender", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("# Fighting Game Character Recommender")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Setup")
                game_input = gr.Textbox(
                    label="Game Name",
                    placeholder="e.g., Street Fighter 6, Guilty Gear Strive",
                    value="Street Fighter 6",
                )
                init_btn = gr.Button("Initialize Game", variant="primary")
                init_output = gr.Markdown()

                gr.Markdown("### User Profile")
                skill_slider = gr.Slider(0.0, 1.0, value=0.3, label="Skill Level")

                stats_display = gr.Markdown("No stats yet")
                refresh_stats_btn = gr.Button("Refresh Stats")

            with gr.Column(scale=2):
                gr.Markdown("### Recommendations")
                top_k_slider = gr.Slider(
                    1, 5, value=3, step=1, label="Number of Recommendations"
                )
                get_rec_btn = gr.Button("Get Recommendations", variant="primary")
                rec_output = gr.Markdown()

                gr.Markdown("### Record Feedback")
                with gr.Row():
                    char_dropdown = gr.Dropdown(label="Character Played", choices=[])
                    won_checkbox = gr.Checkbox(label="Won?", value=False)

                with gr.Row():
                    rating_slider = gr.Slider(
                        1, 5, value=3, step=0.5, label="Rating (1-5)"
                    )
                    playtime_slider = gr.Slider(
                        5, 60, value=20, step=5, label="Playtime (minutes)"
                    )

                submit_btn = gr.Button("Submit Feedback", variant="secondary")
                feedback_output = gr.Markdown()

        def _apply_skill(skill):
            """Copy the slider value into the active user profile, if any."""
            # Before a game is initialised there is no profile to update.
            if app.user is not None:
                app.user = app.user._replace(skill_level=float(skill))
            return app.get_stats()

        def _on_init(game_name, skill):
            """Initialise the game, then re-apply the slider's skill level."""
            result = app.init_game(game_name)
            # init_game creates a fresh profile with a default skill level;
            # carry over whatever the slider currently shows.
            return result, _apply_skill(skill)

        init_btn.click(
            _on_init,
            inputs=[game_input, skill_slider],
            outputs=[init_output, stats_display],
        )

        # Without this listener the slider is purely decorative.
        skill_slider.release(
            _apply_skill, inputs=[skill_slider], outputs=[stats_display]
        )

        get_rec_btn.click(
            lambda k: app.get_recs(int(k)),
            inputs=[top_k_slider],
            outputs=[rec_output, char_dropdown],
        )

        submit_btn.click(
            app.record_feedback,
            inputs=[char_dropdown, won_checkbox, rating_slider, playtime_slider],
            outputs=[feedback_output],
        )

        refresh_stats_btn.click(app.get_stats, outputs=[stats_display])

    return demo
|
|
| |
| # Script entry point: build the UI and serve it when this file is run directly. |
| if __name__ == "__main__": |
|
|
| demo = create_ui() |
| demo.launch() |
|
|
|
|