Upload folder using huggingface_hub
Browse files- README.md +6 -5
- app.py +4 -4
- cli_mood_reader.py +3 -3
- flask_app.py +2 -2
- src/__init__.py +0 -0
- src/config.py +56 -0
- src/data_fetcher.py +112 -0
- src/hn_mood_reader.py +71 -0
- src/model_trainer.py +162 -0
- src/vibe_logic.py +85 -0
README.md
CHANGED
|
@@ -160,11 +160,12 @@ Key parameters can be adjusted in `config.py`:
|
|
| 160 |
βββ app.py # Main Gradio application for fine-tuning
|
| 161 |
βββ cli_mood_reader.py # Interactive command-line mood reader
|
| 162 |
βββ flask_app.py # Standalone Flask application for mood reading
|
| 163 |
-
βββ
|
| 164 |
-
βββ
|
| 165 |
-
βββ
|
| 166 |
-
βββ
|
| 167 |
-
βββ
|
|
|
|
| 168 |
βββ requirements.txt # Python package dependencies
|
| 169 |
βββ README.md # This file
|
| 170 |
βββ artifacts/ # Stores session-specific fine-tuned models and datasets
|
|
|
|
| 160 |
βββ app.py # Main Gradio application for fine-tuning
|
| 161 |
βββ cli_mood_reader.py # Interactive command-line mood reader
|
| 162 |
βββ flask_app.py # Standalone Flask application for mood reading
|
| 163 |
+
├── src/ # Source code for the application
|
| 164 |
+
│   ├── config.py # Central configuration for all modules
|
| 165 |
+
│   ├── data_fetcher.py # Fetches and caches the Hacker News RSS feed
|
| 166 |
+
│   ├── hn_mood_reader.py # Core logic for fetching and scoring
|
| 167 |
+
│   ├── model_trainer.py # Handles model loading and fine-tuning
|
| 168 |
+
│   └── vibe_logic.py # Calculates similarity scores and "vibe" status
|
| 169 |
βββ requirements.txt # Python package dependencies
|
| 170 |
βββ README.md # This file
|
| 171 |
βββ artifacts/ # Stores session-specific fine-tuned models and datasets
|
app.py
CHANGED
|
@@ -9,16 +9,16 @@ from typing import List, Iterable, Tuple, Optional, Callable
|
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
# Import modules
|
| 12 |
-
from data_fetcher import read_hacker_news_rss, format_published_time
|
| 13 |
-
from model_trainer import (
|
| 14 |
authenticate_hf,
|
| 15 |
train_with_dataset,
|
| 16 |
get_top_hits,
|
| 17 |
load_embedding_model,
|
| 18 |
upload_model_to_hub
|
| 19 |
)
|
| 20 |
-
from config import AppConfig
|
| 21 |
-
from vibe_logic import VibeChecker
|
| 22 |
from sentence_transformers import SentenceTransformer
|
| 23 |
|
| 24 |
# --- Main Application Class (Session Scoped) ---
|
|
|
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
# Import modules
|
| 12 |
+
from src.data_fetcher import read_hacker_news_rss, format_published_time
|
| 13 |
+
from src.model_trainer import (
|
| 14 |
authenticate_hf,
|
| 15 |
train_with_dataset,
|
| 16 |
get_top_hits,
|
| 17 |
load_embedding_model,
|
| 18 |
upload_model_to_hub
|
| 19 |
)
|
| 20 |
+
from src.config import AppConfig
|
| 21 |
+
from src.vibe_logic import VibeChecker
|
| 22 |
from sentence_transformers import SentenceTransformer
|
| 23 |
|
| 24 |
# --- Main Application Class (Session Scoped) ---
|
cli_mood_reader.py
CHANGED
|
@@ -7,9 +7,9 @@ from typing import List
|
|
| 7 |
|
| 8 |
# --- Core Logic Imports ---
|
| 9 |
# These modules contain the application's functionality.
|
| 10 |
-
from config import AppConfig
|
| 11 |
-
from hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
-
from vibe_logic import VIBE_THRESHOLDS
|
| 13 |
|
| 14 |
# --- Helper Functions ---
|
| 15 |
|
|
|
|
| 7 |
|
| 8 |
# --- Core Logic Imports ---
|
| 9 |
# These modules contain the application's functionality.
|
| 10 |
+
from src.config import AppConfig
|
| 11 |
+
from src.hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
+
from src.vibe_logic import VIBE_THRESHOLDS
|
| 13 |
|
| 14 |
# --- Helper Functions ---
|
| 15 |
|
flask_app.py
CHANGED
|
@@ -7,8 +7,8 @@ from typing import Optional
|
|
| 7 |
from flask import Flask, render_template
|
| 8 |
|
| 9 |
# Your existing config and core logic
|
| 10 |
-
from config import AppConfig
|
| 11 |
-
from hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
|
| 13 |
# --- Flask App Initialization ---
|
| 14 |
app = Flask(__name__)
|
|
|
|
| 7 |
from flask import Flask, render_template
|
| 8 |
|
| 9 |
# Your existing config and core logic
|
| 10 |
+
from src.config import AppConfig
|
| 11 |
+
from src.hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
|
| 13 |
# --- Flask App Initialization ---
|
| 14 |
app = Flask(__name__)
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Final
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# --- Base Directory Definition ---
|
| 6 |
+
# Use Path for modern, OS-agnostic path handling
|
| 7 |
+
ARTIFACTS_DIR: Final[Path] = Path("artifacts")
|
| 8 |
+
|
| 9 |
+
class AppConfig:
|
| 10 |
+
"""
|
| 11 |
+
Central configuration class for the Hacker News Fine-Tuner application.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
# --- Directory/Environment Configuration ---
|
| 15 |
+
ARTIFACTS_DIR: Final[Path] = ARTIFACTS_DIR
|
| 16 |
+
|
| 17 |
+
# Environment variable for Hugging Face token (used by model_trainer)
|
| 18 |
+
HF_TOKEN: Final[str | None] = os.getenv('HF_TOKEN')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# --- Caching/Data Fetching Configuration ---
|
| 22 |
+
HN_RSS_URL: Final[str] = "https://news.ycombinator.com/rss"
|
| 23 |
+
|
| 24 |
+
# Filename for the pickled cache data (using Path.joinpath)
|
| 25 |
+
CACHE_FILE: Final[Path] = ARTIFACTS_DIR.joinpath("hacker_news_cache.pkl")
|
| 26 |
+
|
| 27 |
+
# Cache duration set to 30 minutes (1800 seconds)
|
| 28 |
+
CACHE_DURATION_SECONDS: Final[int] = 60 * 30
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# --- Model/Training Configuration ---
|
| 32 |
+
|
| 33 |
+
# Name of the pre-trained embedding model
|
| 34 |
+
MODEL_NAME: Final[str] = 'google/embeddinggemma-300M'
|
| 35 |
+
|
| 36 |
+
# Task name for prompting the embedding model (e.g., for instruction tuning)
|
| 37 |
+
TASK_NAME: Final[str] = "Classification"
|
| 38 |
+
|
| 39 |
+
# Output directory for the fine-tuned model
|
| 40 |
+
OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("embedding-gemma-finetuned-hn")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# --- Gradio/App-Specific Configuration ---
|
| 44 |
+
|
| 45 |
+
# Anchor text used for contrastive learning dataset generation
|
| 46 |
+
QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
|
| 47 |
+
|
| 48 |
+
# Number of titles shown for user selection in the Gradio interface
|
| 49 |
+
TOP_TITLES_COUNT: Final[int] = 10
|
| 50 |
+
|
| 51 |
+
# Default export path for the dataset CSV
|
| 52 |
+
DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
|
| 53 |
+
|
| 54 |
+
# Default model for the standalone Mood Reader tab
|
| 55 |
+
DEFAULT_MOOD_READER_MODEL: Final[str] = "bebechien/embedding-gemma-finetuned-hn"
|
| 56 |
+
|
src/data_fetcher.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import feedparser
|
| 2 |
+
import pickle
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Tuple, Any, Optional
|
| 7 |
+
|
| 8 |
+
# Assuming AppConfig is passed in via dependency injection in the refactored main app.
|
| 9 |
+
|
| 10 |
+
def format_published_time(published_parsed: Optional[time.struct_time]) -> str:
    """Render a feedparser time struct as 'YYYY-MM-DD HH:MM', or 'N/A'."""
    if not published_parsed:
        return 'N/A'
    try:
        stamp = time.mktime(published_parsed)
        return datetime.fromtimestamp(stamp).strftime('%Y-%m-%d %H:%M')
    except Exception:
        # Malformed time struct — fall back to the placeholder.
        return 'N/A'
| 19 |
+
|
| 20 |
+
def load_feed_from_cache(config: Any) -> Tuple[Optional[Any], str]:
    """Return (feed, status) from the pickle cache, or (None, reason) on a miss."""
    cache_path = config.CACHE_FILE
    if not os.path.exists(cache_path):
        return None, "Cache file not found."

    try:
        # Freshness check based on the file's modification time.
        age = time.time() - os.path.getmtime(cache_path)
        if age > config.CACHE_DURATION_SECONDS:
            return None, f"Cache expired ({age:.0f}s old, limit is {config.CACHE_DURATION_SECONDS}s)."

        with open(cache_path, 'rb') as fh:
            cached_feed = pickle.load(fh)
        return cached_feed, f"Loaded successfully from cache (Age: {age:.0f}s)."

    except Exception as exc:
        # Corrupted cache: warn, best-effort delete, then signal a re-fetch.
        print(f"Warning: Failed to load cache file. Deleting corrupted cache. Reason: {exc}")
        try:
            os.remove(cache_path)
        except OSError:
            pass  # Ignore if removal fails
        return None, "Cache file corrupted or invalid. Will re-fetch."
| 45 |
+
|
| 46 |
+
def save_feed_to_cache(config: Any, feed: Any) -> None:
    """Pickle the feed object to the configured cache file (best effort)."""
    try:
        with open(config.CACHE_FILE, 'wb') as fh:
            pickle.dump(feed, fh)
    except Exception as exc:
        # Cache persistence is non-critical — report and carry on.
        print(f"Error saving to cache: {exc}")
    else:
        print(f"Successfully saved new feed data to cache: {config.CACHE_FILE}")
| 54 |
+
|
| 55 |
+
def read_hacker_news_rss(config: Any) -> Tuple[Optional[Any], str]:
    """
    Read and parse the Hacker News RSS feed, preferring the local cache.

    Args:
        config: Object exposing HN_RSS_URL, CACHE_FILE and
            CACHE_DURATION_SECONDS (e.g. AppConfig).

    Returns:
        (feed, status_message): the parsed feedparser object, or None when
        nothing usable could be fetched.
    """
    url = config.HN_RSS_URL
    print(f"Attempting to fetch and parse RSS feed from: {url}")
    print("-" * 50)

    # 1. Attempt to load from cache.
    feed, cache_status = load_feed_from_cache(config)
    print(f"Cache Status: {cache_status}")

    if feed is not None:
        return feed, cache_status

    # 2. Cache miss or stale: fetch from the network.
    print("Starting network fetch...")
    try:
        feed = feedparser.parse(url)

        # FIX: feedparser only sets `.status` for HTTP(S) fetches. On a
        # failed/non-HTTP result the bare attribute access raised
        # AttributeError, which the generic handler then mis-reported as
        # an "unexpected error". Guard with getattr instead.
        status = getattr(feed, 'status', None)
        if status is not None and status >= 400:
            status_msg = f"Error fetching the feed. HTTP Status: {status}"
            print(status_msg)
            return None, status_msg

        if feed.bozo:
            # Bozo is set if any error occurred, even non-critical ones.
            print(f"Warning: Failed to fully parse the feed. Reason: {feed.get('bozo_exception')}")

        # 3. Fetch succeeded: persist to cache when we actually got entries.
        if feed.entries:
            save_feed_to_cache(config, feed)
            status_msg = f"Successfully fetched and cached {len(feed.entries)} entries."
        else:
            status_msg = "Fetch successful, but no entries found in the feed."
            print(status_msg)
            feed = None  # Ensure feed is None if no entries

    except Exception as e:
        status_msg = f"An unexpected error occurred during network processing: {e}"
        print(status_msg)
        return None, status_msg

    return feed, status_msg
| 102 |
+
|
| 103 |
+
# Example usage (not part of the refactored module's purpose but good for testing).
# Run with `python -m src.data_fetcher` (package context); running the file
# directly also works via the absolute-import fallback below.
if __name__ == '__main__':
    try:
        from .config import AppConfig
    except ImportError:
        # FIX: when executed as a plain script there is no parent package,
        # so the relative import raises ImportError — fall back to a
        # direct import from the same directory.
        from config import AppConfig

    feed, status = read_hacker_news_rss(AppConfig)
    if feed and feed.entries:
        print(f"\nFetched {len(feed.entries)} entries. Top 3 titles:")
        for entry in feed.entries[:3]:
            print(f"- {entry.title}")
    else:
        print(f"Could not fetch the feed. Status: {status}")
src/hn_mood_reader.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# hn_mood_reader.py
|
| 2 |
+
|
| 3 |
+
import feedparser
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import List
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Assuming these are in separate files as in the original structure
|
| 10 |
+
from .config import AppConfig
|
| 11 |
+
from .data_fetcher import format_published_time
|
| 12 |
+
from .vibe_logic import VibeChecker, VibeResult
|
| 13 |
+
|
| 14 |
+
# --- Data Structures ---
|
| 15 |
+
@dataclass(frozen=True)
class FeedEntry:
    """Stores necessary data for a single HN story, including its calculated mood."""
    title: str  # story headline, taken from the RSS entry
    link: str  # URL of the story itself
    comments_link: str  # URL of the HN comments page ('#' when the feed omits it)
    published_time_str: str  # pre-formatted 'YYYY-MM-DD HH:MM', or 'N/A'
    mood: VibeResult  # similarity-based vibe computed from the title
| 23 |
+
|
| 24 |
+
# --- Core Logic Class ---
|
| 25 |
+
class HnMoodReader:
    """Handles model initialization and mood scoring for Hacker News titles."""

    def __init__(self, model_name: str):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise ImportError("Please install 'sentence-transformers'") from e

        print(f"Initializing SentenceTransformer with model: {model_name}...")
        # Truncate embeddings to 128 dimensions to keep scoring cheap.
        self.model = SentenceTransformer(model_name, truncate_dim=128)
        print("Model initialized successfully.")

        # Shared scorer: compares each title against the configured anchor.
        self.vibe_checker = VibeChecker(
            model=self.model,
            query_anchor=AppConfig.QUERY_ANCHOR,
            task_name=AppConfig.TASK_NAME,
        )
        self.model_name = model_name

    def _get_mood_result(self, title: str) -> VibeResult:
        """Calculates the mood for a title using the VibeChecker."""
        return self.vibe_checker.check(title)

    def fetch_and_score_feed(self) -> List[FeedEntry]:
        """Fetch the HN RSS feed, score every usable entry, and sort best-first."""
        feed = feedparser.parse(AppConfig.HN_RSS_URL)
        if feed.bozo:
            raise IOError(f"Error parsing feed from {AppConfig.HN_RSS_URL}.")

        results: List[FeedEntry] = []
        for item in feed.entries:
            title = item.get('title')
            link = item.get('link')
            # Skip entries missing either field — both are required downstream.
            if not (title and link):
                continue

            results.append(
                FeedEntry(
                    title=title,
                    link=link,
                    comments_link=item.get('comments', '#'),
                    published_time_str=format_published_time(item.published_parsed),
                    mood=self._get_mood_result(title),
                )
            )

        # Highest similarity first (sorted() is stable, like list.sort()).
        return sorted(results, key=lambda e: e.mood.raw_score, reverse=True)
src/model_trainer.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import login, HfApi # Updated import
|
| 2 |
+
from sentence_transformers import SentenceTransformer, util
|
| 3 |
+
from datasets import Dataset
|
| 4 |
+
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
| 5 |
+
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
| 6 |
+
from transformers import TrainerCallback, TrainingArguments
|
| 7 |
+
from typing import List, Callable, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# --- Model/Utility Functions ---
|
| 11 |
+
|
| 12 |
+
def authenticate_hf(token: Optional[str]) -> None:
    """Log into the Hugging Face Hub when a token is provided; otherwise skip."""
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
| 19 |
+
|
| 20 |
+
def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Initialize a SentenceTransformer, auto-placing it across available devices."""
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        st_model = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
        print(f"Model loaded successfully. {st_model.device}")
        return st_model
    except Exception as exc:
        # Surface the failure, then let the caller decide how to handle it.
        print(f"Error loading Sentence Transformer model {model_name}: {exc}")
        raise
| 30 |
+
|
| 31 |
+
def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5
) -> str:
    """Performs semantic search on target_titles and returns a formatted result string."""
    if not target_titles:
        return "No target titles available for search."

    # Embed the query and the corpus (each encoded once per call).
    query_vec = model.encode(query, prompt_name=task_name)
    corpus_vecs = model.encode(target_titles, prompt_name=task_name)

    # Rank the corpus against the query.
    hits = util.semantic_search(query_vec, corpus_vecs, top_k=top_k)[0]

    # One "[title] score" line per hit, best first.
    lines = [
        f"[{target_titles[hit['corpus_id']]}] {hit['score']:.4f}"
        for hit in hits
    ]
    return "\n".join(lines)
| 58 |
+
|
| 59 |
+
def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
    """
    Uploads a local model folder to the Hugging Face Hub.
    Creates the repository if it doesn't exist.

    Args:
        folder_path: Local directory containing the saved model files.
        repo_name: Bare repo name; the authenticated user's namespace is prepended.
        token: Hugging Face access token.

    Returns:
        A human-readable success or failure message.
    """
    try:
        api = HfApi(token=token)

        # Resolve the authenticated user's namespace for the full repo id.
        user_info = api.whoami()
        username = user_info['name']
        repo_id = f"{username}/{repo_name}"
        print(f"Preparing to upload to: {repo_id}")

        # Create the repo (safe if it already exists).
        api.create_repo(repo_id=repo_id, exist_ok=True)

        # Upload the folder.
        url = api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
            repo_type="model"
        )
        # FIX: the success/failure markers were mojibake ("β") — restored
        # to the intended emoji.
        return f"✅ Success! Model published at: {url}"
    except Exception as e:
        print(f"Upload failed: {e}")
        return f"❌ Upload failed: {str(e)}"
| 88 |
+
|
| 89 |
+
# --- Training Class and Function ---
|
| 90 |
+
|
| 91 |
+
class EvaluationCallback(TrainerCallback):
    """
    A callback that runs the semantic search evaluation at the end of each log step.
    The search function is passed in during initialization.
    """
    def __init__(self, search_fn: Callable[[], str]) -> None:
        # Zero-argument callable that returns the formatted evaluation string.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs) -> None:
        # Invoked by the Trainer on each log event; prints current
        # semantic-search results so training progress is visible.
        print(f"Step {state.global_step} finished. Running evaluation:")
        print(f"\n{self.search_fn()}\n")
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str]
) -> None:
    """
    Fine-tunes the provided Sentence Transformer MODEL on the dataset.

    The dataset should be a list of lists: [[anchor, positive, negative], ...].

    Args:
        model: SentenceTransformer instance, fine-tuned in place.
        dataset: Contrastive triplets, one [anchor, positive, negative] per row.
        output_dir: Directory where the final fine-tuned model is saved.
        task_name: Key used to look up the model's prompt for this task.
        search_fn: Zero-arg callable run by EvaluationCallback on each log event.
    """
    # Convert to Hugging Face Dataset format
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]

    train_dataset = Dataset.from_list(data_as_dicts)

    # Use MultipleNegativesRankingLoss, suitable for contrastive learning
    loss = MultipleNegativesRankingLoss(model)

    # Note: SentenceTransformer models typically have a 'prompts' attribute
    # which we need to access for the training arguments.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # Fallback to an empty list or appropriate default if required by the model's structure
        # NOTE(review): an empty list may not be a valid `prompts` value for
        # SentenceTransformerTrainingArguments — confirm against library docs.
        prompts = []

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # Log once per epoch-equivalent (one row per step at batch size 1).
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no" # No saving during training, only at the end
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)]
    )

    trainer.train()

    print("Training finished. Model weights are updated in memory.")

    # Save the final fine-tuned model
    trainer.save_model()

    print(f"Model saved locally to: {output_dir}")
src/vibe_logic.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from math import floor
|
| 3 |
+
from typing import List
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
+
|
| 6 |
+
# --- Data Structures ---
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
class VibeThreshold:
    """Defines a threshold for a Vibe status."""
    score: float  # inclusive lower bound on the clamped similarity score
    status: str  # label displayed when the score reaches this threshold

@dataclass(frozen=True)
class VibeResult:
    """Stores the calculated HSL color and status for a given score."""
    raw_score: float  # original (unclamped) similarity score
    status_html: str # Pre-formatted HTML for display
    color_hsl: str # Raw HSL color string

# Define the status thresholds from highest score to lowest score.
# map_score_to_vibe picks the FIRST entry whose score the value meets,
# so the ordering here is load-bearing.
# NOTE(review): the status glyphs (e.g. "β¨") look mojibake'd — likely
# originally emoji; confirm against the rendered UI before changing them.
VIBE_THRESHOLDS: List[VibeThreshold] = [
    VibeThreshold(score=0.8, status="β¨ VIBE:HIGH"),
    VibeThreshold(score=0.5, status="π VIBE:GOOD"),
    VibeThreshold(score=0.2, status="π VIBE:FLAT"),
    VibeThreshold(score=0.0, status="π VIBE:LOW "), # Base case for scores < 0.2
]
| 28 |
+
|
| 29 |
+
# --- Utility Functions ---
|
| 30 |
+
|
| 31 |
+
def map_score_to_vibe(score: float) -> VibeResult:
    """
    Maps a cosine similarity score to a VibeResult containing status, HTML, and color.
    """
    # Clamp into [0, 1] so color/status lookups stay well-defined.
    bounded = min(1.0, max(0.0, score))

    # Color: linearly interpolate hue from 0 (red) to 120 (green).
    color_hsl = f"hsl({floor(bounded * 120)}, 80%, 50%)"

    # Status: first threshold (highest first) that the score meets;
    # the final entry (score 0.0) acts as the catch-all default.
    status_text = next(
        (t.status for t in VIBE_THRESHOLDS if bounded >= t.score),
        VIBE_THRESHOLDS[-1].status,
    )

    # Pre-render the colored label for direct display.
    status_html = f"<span style='color: {color_hsl}; font-weight: bold;'>{status_text}</span>"

    # Note: the result keeps the ORIGINAL (unclamped) score.
    return VibeResult(raw_score=score, status_html=status_html, color_hsl=color_hsl)
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- Core Logic Class ---
|
| 56 |
+
|
| 57 |
+
class VibeChecker:
    """
    Handles similarity scoring using a SentenceTransformer model and a pre-set anchor query.
    """

    def __init__(self, model: SentenceTransformer, query_anchor: str, task_name: str):
        self.model = model
        self.query_anchor = query_anchor
        self.task_name = task_name

        # Embed the anchor once up front; every check() reuses this vector.
        self.query_embedding = self.model.encode(
            self.query_anchor,
            prompt_name=self.task_name,
            normalize_embeddings=True,
        )

    def check(self, text: str) -> VibeResult:
        """
        Calculates the "vibe" of a given text against the pre-configured anchor.
        """
        candidate = self.model.encode(
            text,
            prompt_name=self.task_name,
            normalize_embeddings=True,
        )
        # With unit-normalized vectors the dot product equals cosine similarity.
        similarity: float = util.dot_score(self.query_embedding, candidate).item()
        return map_score_to_vibe(similarity)