| import os |
| import shutil |
| import time |
| import csv |
| import uuid |
| from itertools import cycle |
| from typing import List, Tuple, Optional |
| from datetime import datetime |
| import gradio as gr |
|
|
| from .data_fetcher import read_hacker_news_rss, format_published_time |
| from .model_trainer import ( |
| authenticate_hf, |
| train_with_dataset, |
| get_top_hits, |
| load_embedding_model, |
| upload_model_to_hub |
| ) |
| from .config import AppConfig |
| from .vibe_logic import VibeChecker |
| from sentence_transformers import SentenceTransformer |
|
|
class HackerNewsFineTuner:
    """
    Encapsulates all application logic and state for a single user session.

    Each instance gets its own UUID-named artifact directory under
    ``config.ARTIFACTS_DIR`` that holds the fine-tuned model and the exported
    training CSV, so sessions do not overwrite each other's files.
    """

    def __init__(self, config: AppConfig = AppConfig):
        """
        Start a new session: allocate its artifact directory and authenticate.

        Args:
            config: Configuration providing ARTIFACTS_DIR, HF_TOKEN, MODEL_NAME,
                QUERY_ANCHOR and TASK_NAME. The default passes the ``AppConfig``
                class itself, so its class attributes are read directly.
        """
        self.config = config

        # A fresh UUID per session isolates on-disk artifacts between users.
        self.session_id = str(uuid.uuid4())

        self.session_root = self.config.ARTIFACTS_DIR / self.session_id
        self.output_dir = self.session_root / "embedding_gemma_finetuned"
        self.dataset_export_file = self.session_root / "training_dataset.csv"

        # Also creates session_root. Because output_dir always exists after
        # this, "was a model trained?" must inspect its contents, not existence.
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")

        self.model: Optional[SentenceTransformer] = None
        self.vibe_checker: Optional[VibeChecker] = None
        # Story titles currently shown in the UI; pos/neg ids index into this.
        self.titles: List[str] = []
        # Triplets (anchor, positive, negative) used by the latest training run.
        self.last_hn_dataset: List[List[str]] = []
        # Triplets from an uploaded CSV; when non-empty they replace the
        # story selections as training data.
        self.imported_dataset: List[List[str]] = []

        authenticate_hf(self.config.HF_TOKEN)

    def _update_vibe_checker(self):
        """Initializes or updates the VibeChecker with the current model state."""
        if self.model:
            self.vibe_checker = VibeChecker(
                model=self.model,
                query_anchor=self.config.QUERY_ANCHOR,
                task_name=self.config.TASK_NAME
            )
        else:
            self.vibe_checker = None

    def _has_trained_model(self) -> bool:
        """Return True if the session's output directory exists and is non-empty.

        ``__init__`` always creates ``output_dir``, so a bare existence check
        cannot tell whether training has written anything there.
        """
        return os.path.isdir(self.output_dir) and bool(os.listdir(self.output_dir))

    def _tuned_model_label(self) -> str:
        """Return the tuned-model label shown in the UI info panels."""
        return f"./{self.output_dir}" if self.last_hn_dataset else "<unsaved>"

    def refresh_data_and_model(self) -> Tuple[List[str], str]:
        """
        Reload the embedding model and re-fetch the Hacker News feed.

        Clears any generated or imported training data first.

        Returns:
            - List of titles (for the UI). On model-load failure this is ``[]``
              and ``self.titles`` is left unchanged.
            - Status message string
        """
        print(f"[{self.session_id}] Reloading model and data...")

        self.last_hn_dataset = []
        self.imported_dataset = []

        try:
            self.model = load_embedding_model(self.config.MODEL_NAME)
            self._update_vibe_checker()
        except Exception as e:
            error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
            print(error_msg)
            self.model = None
            self._update_vibe_checker()
            return [], error_msg

        news_feed, status_msg = read_hacker_news_rss(self.config)
        status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"

        if news_feed is not None and news_feed.entries:
            self.titles = [item.title for item in news_feed.entries]
        else:
            self.titles = ["Error fetching news."]
            gr.Warning(f"Data reload failed. {status_msg}")

        return self.titles, status_value

    def import_additional_dataset(self, file_path: str) -> str:
        """
        Load (anchor, positive, negative) triplets from a CSV file.

        A first row whose first cell is "anchor" (case-insensitive) is treated
        as a header and skipped. Only rows with exactly three non-empty cells
        are kept. On success the triplets replace any previously imported set.

        Args:
            file_path: Path of the uploaded CSV; falsy when nothing was uploaded.

        Returns:
            A user-facing status message (errors are reported, not raised).
        """
        if not file_path:
            return "Please upload a CSV file."
        try:
            # utf-8-sig strips a leading BOM (typical of spreadsheet exports);
            # with plain utf-8 the header cell reads "\ufeffAnchor", fails the
            # header check, and the header gets imported as a triplet.
            with open(file_path, 'r', newline='', encoding='utf-8-sig') as f:
                rows = list(csv.reader(f))
            if not rows:
                return "Error: Uploaded file is empty."
            if rows[0] and rows[0][0].lower().strip() == 'anchor':
                rows = rows[1:]

            new_dataset = []
            for row in rows:
                if len(row) != 3:
                    continue
                cells = [s.strip() for s in row]
                # A blank anchor/positive/negative is not a usable example.
                if all(cells):
                    new_dataset.append(cells)
            if not new_dataset:
                raise ValueError("No valid rows found.")
            self.imported_dataset = new_dataset
            return f"Imported {len(new_dataset)} triplets."
        except Exception as e:
            # UI boundary: report any read/parse failure as a status message.
            return f"Import failed: {e}"

    def export_dataset(self) -> Optional[str]:
        """
        Write the latest training triplets to the session's CSV file.

        Returns:
            The CSV path as a string, or None when there is nothing to export
            or the write failed (a warning toast is shown in both cases).
        """
        if not self.last_hn_dataset:
            gr.Warning("No dataset generated yet.")
            return None

        file_path = self.dataset_export_file
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(['Anchor', 'Positive', 'Negative'])
                writer.writerows(self.last_hn_dataset)
        except (OSError, csv.Error) as e:
            # gr.Error is an exception class: building one without `raise`
            # shows nothing. gr.Warning surfaces the message and keeps the
            # None-on-failure return contract.
            gr.Warning(f"Export failed: {e}")
            return None
        gr.Info("Dataset exported.")
        return str(file_path)

    def download_model(self) -> Optional[str]:
        """
        Zip the session's output directory for download.

        Returns:
            Path of the created .zip archive, or None when no model output
            exists or zipping failed (a warning toast is shown).
        """
        if not self._has_trained_model():
            gr.Warning("No model trained yet.")
            return None

        timestamp = int(time.time())
        # The archive goes to session_root (the parent of output_dir), so it
        # is never included in its own contents.
        base_name = self.session_root / f"model_finetuned_{timestamp}"
        try:
            archive_path = shutil.make_archive(
                base_name=str(base_name),
                format='zip',
                root_dir=self.output_dir,
            )
        except OSError as e:
            # See export_dataset: an un-raised gr.Error would display nothing.
            gr.Warning(f"Zip failed: {e}")
            return None
        gr.Info("Model zipped.")
        return archive_path

    def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
        """
        Calls the model trainer upload function using the session's output directory.

        Args:
            repo_name: Target Hub repository name; surrounding whitespace is removed.
            oauth_token_str: Token forwarded to ``upload_model_to_hub``.

        Returns:
            A user-facing status message.
        """
        if not self._has_trained_model():
            return "❌ Error: No trained model found in this session. Run training first."
        repo_name = (repo_name or "").strip()
        if not repo_name:
            return "❌ Error: Please specify a repository name."

        return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)

    def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]:
        """
        Creates triplets (Anchor, Positive, Negative) from the selected indices.

        The shorter of the two title lists is cycled, so every title of the
        longer list is used exactly once and the result has
        ``max(len(pos_ids), len(neg_ids))`` triplets. Returns ``[]`` if either
        list is empty.
        """
        if not pos_ids or not neg_ids:
            return []

        pos_titles = [self.titles[i] for i in pos_ids]
        neg_titles = [self.titles[i] for i in neg_ids]
        anchor = self.config.QUERY_ANCHOR

        # zip() stops at the finite list; cycle() supplies the shorter side.
        if len(pos_titles) >= len(neg_titles):
            return [[anchor, p, n] for p, n in zip(pos_titles, cycle(neg_titles))]
        return [[anchor, p, n] for p, n in zip(cycle(pos_titles), neg_titles)]

    def training(self, pos_ids: List[int], neg_ids: List[int]) -> str:
        """
        Main training entry point.

        An imported CSV dataset, when present, is used as-is and the story
        selections are ignored; otherwise triplets are built from the selections.

        Args:
            pos_ids: Indices into ``self.titles`` of stories marked as "Favorite"
            neg_ids: Indices into ``self.titles`` of stories marked as "Dislike"

        Returns:
            Markdown comparing semantic-search results before and after training.

        Raises:
            gr.Error: If the model is not loaded, a selection is missing, or the
                resulting dataset is empty.
        """
        if self.model is None:
            raise gr.Error("Model not loaded.")

        if self.imported_dataset:
            self.last_hn_dataset = self.imported_dataset
        else:
            if not pos_ids:
                raise gr.Error("Please select at least one 'Favorite' story.")
            if not neg_ids:
                raise gr.Error("Please select at least one 'Dislike' story.")
            self.last_hn_dataset = self._create_hn_dataset(pos_ids, neg_ids)

        if not self.last_hn_dataset:
            raise gr.Error("Dataset generation failed (Empty dataset).")

        def semantic_search_fn() -> str:
            # Reads self.model at call time, so it reflects the model's current state.
            return get_top_hits(
                model=self.model,
                target_titles=self.titles,
                task_name=self.config.TASK_NAME,
                query=self.config.QUERY_ANCHOR,
            )

        result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
        print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...")

        train_with_dataset(
            model=self.model,
            dataset=self.last_hn_dataset,
            output_dir=self.output_dir,
            task_name=self.config.TASK_NAME,
            search_fn=semantic_search_fn
        )

        self._update_vibe_checker()
        print(f"[{self.session_id}] Training Complete.")

        result += "### Search (After):\n" + f"{semantic_search_fn()}"
        return result

    def is_model_tuned(self) -> bool:
        """Return True once a training dataset has been assigned.

        Note: ``training`` assigns ``last_hn_dataset`` before calling
        ``train_with_dataset``, so this is also True if that call later failed.
        """
        return bool(self.last_hn_dataset)

    def get_vibe_check(self, news_text: str) -> Tuple[str, str, dict, str]:
        """
        Score a single piece of text with the current vibe checker.

        Returns:
            (score string, status label, ``gr.update`` payload carrying the
            mood-lamp CSS, markdown info text about the session/models).
            Errors from the checker are reported in the status label.
        """
        info_text = (f"**Session:** {self.session_id[:6]}<br>"
                     f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
                     f"**Tuned Model:** `{self._tuned_model_label()}`")

        if not self.vibe_checker:
            return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_css("gray")), info_text
        if not news_text or len(news_text.split()) < 3:
            return "N/A", "Text too short", gr.update(value=self._generate_vibe_css("gray")), info_text

        try:
            vibe_result = self.vibe_checker.check(news_text)
            # Takes the text between the first '>' and the next '<' —
            # presumably the label inside status_html's first tag.
            status = vibe_result.status_html.split('>')[1].split('<')[0]
            return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_css(vibe_result.color_hsl)), info_text
        except Exception as e:
            return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_css("gray")), info_text

    def _generate_vibe_css(self, color: str) -> str:
        """Generates a style block to update the Mood Lamp textbox background."""
        return f"<style>#mood_lamp input {{ background-color: {color} !important; transition: background-color 0.5s ease; }}</style>"

    @staticmethod
    def _escape_md_cell(text: str) -> str:
        """Escape characters that would break a markdown table cell or link text."""
        # Backslash first so the escapes added below are not themselves doubled.
        for ch in ("\\", "|", "[", "]"):
            text = text.replace(ch, "\\" + ch)
        return text

    def fetch_and_display_mood_feed(self) -> str:
        """
        Fetch the feed, score every titled entry, and render a markdown table.

        Rows are sorted by vibe score, highest first. Returns a short message
        instead of a table when the model or the feed is unavailable.
        """
        if not self.vibe_checker:
            return "Model not ready. Please wait or reload."

        feed, status = read_hacker_news_rss(self.config)
        if not feed or not feed.entries:
            return f"**Feed Error:** {status}"

        scored_entries = []
        for entry in feed.entries:
            title = entry.get('title')
            if not title:
                continue
            scored_entries.append({
                "title": title,
                "link": entry.get('link', '#'),
                "comments": entry.get('comments', '#'),
                "published": format_published_time(entry.published_parsed),
                "mood": self.vibe_checker.check(title),
            })

        scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)

        header = (f"## Hacker News Top Stories\n"
                  f"**Session:** {self.session_id[:6]}<br>"
                  f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
                  f"**Tuned Model:** `{self._tuned_model_label()}`<br>"
                  f"**Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                  "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")

        # Titles are escaped: a raw '|' would split the table row and
        # unbalanced brackets would break the link syntax.
        rows = [
            f"| {item['mood'].status_html} "
            f"| {item['mood'].raw_score:.4f} "
            f"| [{self._escape_md_cell(item['title'])}]({item['link']}) "
            f"| [Comments]({item['comments']}) "
            f"| {item['published']} |\n"
            for item in scored_entries
        ]
        return header + "".join(rows)
|
|