bebechien committed on
Commit
beabfb7
Β·
verified Β·
1 Parent(s): c5ec0d9

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -160,11 +160,12 @@ Key parameters can be adjusted in `config.py`:
160
  β”œβ”€β”€ app.py # Main Gradio application for fine-tuning
161
  β”œβ”€β”€ cli_mood_reader.py # Interactive command-line mood reader
162
  β”œβ”€β”€ flask_app.py # Standalone Flask application for mood reading
163
- β”œβ”€β”€ hn_mood_reader.py # Core logic for fetching and scoring (used by Flask/CLI)
164
- β”œβ”€β”€ model_trainer.py # Handles model loading and fine-tuning
165
- β”œβ”€β”€ vibe_logic.py # Calculates similarity scores and "vibe" status
166
- β”œβ”€β”€ data_fetcher.py # Fetches and caches the Hacker News RSS feed
167
- β”œβ”€β”€ config.py # Central configuration for all modules
 
168
  β”œβ”€β”€ requirements.txt # Python package dependencies
169
  β”œβ”€β”€ README.md # This file
170
  β”œβ”€β”€ artifacts/ # Stores session-specific fine-tuned models and datasets
 
160
  β”œβ”€β”€ app.py # Main Gradio application for fine-tuning
161
  β”œβ”€β”€ cli_mood_reader.py # Interactive command-line mood reader
162
  β”œβ”€β”€ flask_app.py # Standalone Flask application for mood reading
163
+ β”œβ”€β”€ src/ # Source code for the application
164
+ β”‚ β”œβ”€β”€ config.py # Central configuration for all modules
165
+ β”‚ β”œβ”€β”€ data_fetcher.py # Fetches and caches the Hacker News RSS feed
166
+ β”‚ β”œβ”€β”€ hn_mood_reader.py # Core logic for fetching and scoring
167
+ β”‚ β”œβ”€β”€ model_trainer.py # Handles model loading and fine-tuning
168
+ β”‚ └── vibe_logic.py # Calculates similarity scores and "vibe" status
169
  β”œβ”€β”€ requirements.txt # Python package dependencies
170
  β”œβ”€β”€ README.md # This file
171
  β”œβ”€β”€ artifacts/ # Stores session-specific fine-tuned models and datasets
app.py CHANGED
@@ -9,16 +9,16 @@ from typing import List, Iterable, Tuple, Optional, Callable
9
  from datetime import datetime
10
 
11
  # Import modules
12
- from data_fetcher import read_hacker_news_rss, format_published_time
13
- from model_trainer import (
14
  authenticate_hf,
15
  train_with_dataset,
16
  get_top_hits,
17
  load_embedding_model,
18
  upload_model_to_hub
19
  )
20
- from config import AppConfig
21
- from vibe_logic import VibeChecker
22
  from sentence_transformers import SentenceTransformer
23
 
24
  # --- Main Application Class (Session Scoped) ---
 
9
  from datetime import datetime
10
 
11
  # Import modules
12
+ from src.data_fetcher import read_hacker_news_rss, format_published_time
13
+ from src.model_trainer import (
14
  authenticate_hf,
15
  train_with_dataset,
16
  get_top_hits,
17
  load_embedding_model,
18
  upload_model_to_hub
19
  )
20
+ from src.config import AppConfig
21
+ from src.vibe_logic import VibeChecker
22
  from sentence_transformers import SentenceTransformer
23
 
24
  # --- Main Application Class (Session Scoped) ---
cli_mood_reader.py CHANGED
@@ -7,9 +7,9 @@ from typing import List
7
 
8
  # --- Core Logic Imports ---
9
  # These modules contain the application's functionality.
10
- from config import AppConfig
11
- from hn_mood_reader import HnMoodReader, FeedEntry
12
- from vibe_logic import VIBE_THRESHOLDS
13
 
14
  # --- Helper Functions ---
15
 
 
7
 
8
  # --- Core Logic Imports ---
9
  # These modules contain the application's functionality.
10
+ from src.config import AppConfig
11
+ from src.hn_mood_reader import HnMoodReader, FeedEntry
12
+ from src.vibe_logic import VIBE_THRESHOLDS
13
 
14
  # --- Helper Functions ---
15
 
flask_app.py CHANGED
@@ -7,8 +7,8 @@ from typing import Optional
7
  from flask import Flask, render_template
8
 
9
  # Your existing config and core logic
10
- from config import AppConfig
11
- from hn_mood_reader import HnMoodReader, FeedEntry
12
 
13
  # --- Flask App Initialization ---
14
  app = Flask(__name__)
 
7
  from flask import Flask, render_template
8
 
9
  # Your existing config and core logic
10
+ from src.config import AppConfig
11
+ from src.hn_mood_reader import HnMoodReader, FeedEntry
12
 
13
  # --- Flask App Initialization ---
14
  app = Flask(__name__)
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Final
3
+ from pathlib import Path
4
+
5
+ # --- Base Directory Definition ---
6
+ # Use Path for modern, OS-agnostic path handling
7
+ ARTIFACTS_DIR: Final[Path] = Path("artifacts")
8
+
9
+ class AppConfig:
10
+ """
11
+ Central configuration class for the Hacker News Fine-Tuner application.
12
+ """
13
+
14
+ # --- Directory/Environment Configuration ---
15
+ ARTIFACTS_DIR: Final[Path] = ARTIFACTS_DIR
16
+
17
+ # Environment variable for Hugging Face token (used by model_trainer)
18
+ HF_TOKEN: Final[str | None] = os.getenv('HF_TOKEN')
19
+
20
+
21
+ # --- Caching/Data Fetching Configuration ---
22
+ HN_RSS_URL: Final[str] = "https://news.ycombinator.com/rss"
23
+
24
+ # Filename for the pickled cache data (using Path.joinpath)
25
+ CACHE_FILE: Final[Path] = ARTIFACTS_DIR.joinpath("hacker_news_cache.pkl")
26
+
27
+ # Cache duration set to 30 minutes (1800 seconds)
28
+ CACHE_DURATION_SECONDS: Final[int] = 60 * 30
29
+
30
+
31
+ # --- Model/Training Configuration ---
32
+
33
+ # Name of the pre-trained embedding model
34
+ MODEL_NAME: Final[str] = 'google/embeddinggemma-300M'
35
+
36
+ # Task name for prompting the embedding model (e.g., for instruction tuning)
37
+ TASK_NAME: Final[str] = "Classification"
38
+
39
+ # Output directory for the fine-tuned model
40
+ OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("embedding-gemma-finetuned-hn")
41
+
42
+
43
+ # --- Gradio/App-Specific Configuration ---
44
+
45
+ # Anchor text used for contrastive learning dataset generation
46
+ QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
47
+
48
+ # Number of titles shown for user selection in the Gradio interface
49
+ TOP_TITLES_COUNT: Final[int] = 10
50
+
51
+ # Default export path for the dataset CSV
52
+ DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
53
+
54
+ # Default model for the standalone Mood Reader tab
55
+ DEFAULT_MOOD_READER_MODEL: Final[str] = "bebechien/embedding-gemma-finetuned-hn"
56
+
src/data_fetcher.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import feedparser
2
+ import pickle
3
+ import os
4
+ import time
5
+ from datetime import datetime
6
+ from typing import Tuple, Any, Optional
7
+
8
+ # Assuming AppConfig is passed in via dependency injection in the refactored main app.
9
+
10
+ def format_published_time(published_parsed: Optional[time.struct_time]) -> str:
11
+ """Safely converts a feedparser time struct to a formatted string."""
12
+ if published_parsed:
13
+ try:
14
+ dt_obj = datetime.fromtimestamp(time.mktime(published_parsed))
15
+ return dt_obj.strftime('%Y-%m-%d %H:%M')
16
+ except Exception:
17
+ return 'N/A'
18
+ return 'N/A'
19
+
20
+ def load_feed_from_cache(config: Any) -> Tuple[Optional[Any], str]:
21
+ """Attempts to load a feed object from the cache file if it exists and is not expired."""
22
+ if not os.path.exists(config.CACHE_FILE):
23
+ return None, "Cache file not found."
24
+
25
+ try:
26
+ # Check cache age
27
+ file_age_seconds = time.time() - os.path.getmtime(config.CACHE_FILE)
28
+
29
+ if file_age_seconds > config.CACHE_DURATION_SECONDS:
30
+ # The cache is too old
31
+ return None, f"Cache expired ({file_age_seconds:.0f}s old, limit is {config.CACHE_DURATION_SECONDS}s)."
32
+
33
+ with open(config.CACHE_FILE, 'rb') as f:
34
+ feed = pickle.load(f)
35
+ return feed, f"Loaded successfully from cache (Age: {file_age_seconds:.0f}s)."
36
+
37
+ except Exception as e:
38
+ # If loading fails, treat it as a miss and attempt to clean up
39
+ print(f"Warning: Failed to load cache file. Deleting corrupted cache. Reason: {e}")
40
+ try:
41
+ os.remove(config.CACHE_FILE)
42
+ except OSError:
43
+ pass # Ignore if removal fails
44
+ return None, "Cache file corrupted or invalid. Will re-fetch."
45
+
46
+ def save_feed_to_cache(config: Any, feed: Any) -> None:
47
+ """Saves the fetched feed object to the cache file."""
48
+ try:
49
+ with open(config.CACHE_FILE, 'wb') as f:
50
+ pickle.dump(feed, f)
51
+ print(f"Successfully saved new feed data to cache: {config.CACHE_FILE}")
52
+ except Exception as e:
53
+ print(f"Error saving to cache: {e}")
54
+
55
+ def read_hacker_news_rss(config: Any) -> Tuple[Optional[Any], str]:
56
+ """
57
+ Reads and parses the Hacker News RSS feed, using a cache if available.
58
+ Returns the feedparser object and a status message.
59
+ """
60
+ url = config.HN_RSS_URL
61
+ print(f"Attempting to fetch and parse RSS feed from: {url}")
62
+ print("-" * 50)
63
+
64
+ # 1. Attempt to load from cache
65
+ feed, cache_status = load_feed_from_cache(config)
66
+ print(f"Cache Status: {cache_status}")
67
+
68
+ # 2. If cache miss or stale, fetch from web
69
+ if feed is None:
70
+ print("Starting network fetch...")
71
+ try:
72
+ # Use feedparser to fetch and parse the feed
73
+ feed = feedparser.parse(url)
74
+
75
+ if feed.status >= 400:
76
+ status_msg = f"Error fetching the feed. HTTP Status: {feed.status}"
77
+ print(status_msg)
78
+ return None, status_msg
79
+
80
+ if feed.bozo:
81
+ # Bozo is set if any error occurred, even non-critical ones.
82
+ print(f"Warning: Failed to fully parse the feed. Reason: {feed.get('bozo_exception')}")
83
+
84
+ # 3. If fetch successful, save new data to cache
85
+ if feed.entries:
86
+ save_feed_to_cache(config, feed)
87
+ status_msg = f"Successfully fetched and cached {len(feed.entries)} entries."
88
+ else:
89
+ status_msg = "Fetch successful, but no entries found in the feed."
90
+ print(status_msg)
91
+ feed = None # Ensure feed is None if no entries
92
+
93
+ except Exception as e:
94
+ status_msg = f"An unexpected error occurred during network processing: {e}"
95
+ print(status_msg)
96
+ return None, status_msg
97
+
98
+ else:
99
+ status_msg = cache_status
100
+
101
+ return feed, status_msg
102
+
103
+ # Example usage (not part of the refactored module's purpose but good for testing)
104
+ if __name__ == '__main__':
105
+ from .config import AppConfig
106
+ feed, status = read_hacker_news_rss(AppConfig)
107
+ if feed and feed.entries:
108
+ print(f"\nFetched {len(feed.entries)} entries. Top 3 titles:")
109
+ for entry in feed.entries[:3]:
110
+ print(f"- {entry.title}")
111
+ else:
112
+ print(f"Could not fetch the feed. Status: {status}")
src/hn_mood_reader.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # hn_mood_reader.py
2
+
3
+ import feedparser
4
+ from datetime import datetime
5
+ from dataclasses import dataclass
6
+ from typing import List
7
+ import os
8
+
9
+ # Assuming these are in separate files as in the original structure
10
+ from .config import AppConfig
11
+ from .data_fetcher import format_published_time
12
+ from .vibe_logic import VibeChecker, VibeResult
13
+
14
+ # --- Data Structures ---
15
+ @dataclass(frozen=True)
16
+ class FeedEntry:
17
+ """Stores necessary data for a single HN story, including its calculated mood."""
18
+ title: str
19
+ link: str
20
+ comments_link: str
21
+ published_time_str: str
22
+ mood: VibeResult
23
+
24
+ # --- Core Logic Class ---
25
+ class HnMoodReader:
26
+ """Handles model initialization and mood scoring for Hacker News titles."""
27
+ def __init__(self, model_name: str):
28
+ try:
29
+ from sentence_transformers import SentenceTransformer
30
+ except ImportError as e:
31
+ raise ImportError("Please install 'sentence-transformers'") from e
32
+
33
+ print(f"Initializing SentenceTransformer with model: {model_name}...")
34
+ self.model = SentenceTransformer(model_name, truncate_dim=128)
35
+ print("Model initialized successfully.")
36
+
37
+ self.vibe_checker = VibeChecker(
38
+ model=self.model,
39
+ query_anchor=AppConfig.QUERY_ANCHOR,
40
+ task_name=AppConfig.TASK_NAME
41
+ )
42
+ self.model_name = model_name
43
+
44
+ def _get_mood_result(self, title: str) -> VibeResult:
45
+ """Calculates the mood for a title using the VibeChecker."""
46
+ return self.vibe_checker.check(title)
47
+
48
+ def fetch_and_score_feed(self) -> List[FeedEntry]:
49
+ """Fetches, scores, and sorts entries from the HN RSS feed."""
50
+ feed = feedparser.parse(AppConfig.HN_RSS_URL)
51
+ if feed.bozo:
52
+ raise IOError(f"Error parsing feed from {AppConfig.HN_RSS_URL}.")
53
+
54
+ scored_entries: List[FeedEntry] = []
55
+ for entry in feed.entries:
56
+ title, link = entry.get('title'), entry.get('link')
57
+ if not title or not link:
58
+ continue
59
+
60
+ scored_entries.append(
61
+ FeedEntry(
62
+ title=title,
63
+ link=link,
64
+ comments_link=entry.get('comments', '#'),
65
+ published_time_str=format_published_time(entry.published_parsed),
66
+ mood=self._get_mood_result(title)
67
+ )
68
+ )
69
+
70
+ scored_entries.sort(key=lambda x: x.mood.raw_score, reverse=True)
71
+ return scored_entries
src/model_trainer.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import login, HfApi # Updated import
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from datasets import Dataset
4
+ from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
5
+ from sentence_transformers.losses import MultipleNegativesRankingLoss
6
+ from transformers import TrainerCallback, TrainingArguments
7
+ from typing import List, Callable, Optional
8
+ from pathlib import Path
9
+
10
+ # --- Model/Utility Functions ---
11
+
12
+ def authenticate_hf(token: Optional[str]) -> None:
13
+ """Logs into the Hugging Face Hub."""
14
+ if token:
15
+ print("Logging into Hugging Face Hub...")
16
+ login(token=token)
17
+ else:
18
+ print("Skipping Hugging Face login: HF_TOKEN not set.")
19
+
20
+ def load_embedding_model(model_name: str) -> SentenceTransformer:
21
+ """Initializes the Sentence Transformer model."""
22
+ print(f"Loading Sentence Transformer model: {model_name}")
23
+ try:
24
+ model = SentenceTransformer(model_name, model_kwargs={"device_map": "auto"})
25
+ print(f"Model loaded successfully. {model.device}")
26
+ return model
27
+ except Exception as e:
28
+ print(f"Error loading Sentence Transformer model {model_name}: {e}")
29
+ raise
30
+
31
+ def get_top_hits(
32
+ model: SentenceTransformer,
33
+ target_titles: List[str],
34
+ task_name: str,
35
+ query: str = "MY_FAVORITE_NEWS",
36
+ top_k: int = 5
37
+ ) -> str:
38
+ """Performs semantic search on target_titles and returns a formatted result string."""
39
+ if not target_titles:
40
+ return "No target titles available for search."
41
+
42
+ # Encode the query
43
+ query_embedding = model.encode(query, prompt_name=task_name)
44
+
45
+ # Encode the target titles (only done once per call)
46
+ title_embeddings = model.encode(target_titles, prompt_name=task_name)
47
+
48
+ # Perform semantic search
49
+ top_hits = util.semantic_search(query_embedding, title_embeddings, top_k=top_k)[0]
50
+
51
+ result = []
52
+ for hit in top_hits:
53
+ title = target_titles[hit['corpus_id']]
54
+ score = hit['score']
55
+ result.append(f"[{title}] {score:.4f}")
56
+
57
+ return "\n".join(result)
58
+
59
+ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
60
+ """
61
+ Uploads a local model folder to the Hugging Face Hub.
62
+ Creates the repository if it doesn't exist.
63
+ """
64
+ try:
65
+ api = HfApi(token=token)
66
+
67
+ # Get the authenticated user's username
68
+ user_info = api.whoami()
69
+ username = user_info['name']
70
+
71
+ # Construct the full repo ID
72
+ repo_id = f"{username}/{repo_name}"
73
+ print(f"Preparing to upload to: {repo_id}")
74
+
75
+ # Create the repo (safe if it already exists)
76
+ api.create_repo(repo_id=repo_id, exist_ok=True)
77
+
78
+ # Upload the folder
79
+ url = api.upload_folder(
80
+ folder_path=folder_path,
81
+ repo_id=repo_id,
82
+ repo_type="model"
83
+ )
84
+ return f"βœ… Success! Model published at: {url}"
85
+ except Exception as e:
86
+ print(f"Upload failed: {e}")
87
+ return f"❌ Upload failed: {str(e)}"
88
+
89
+ # --- Training Class and Function ---
90
+
91
+ class EvaluationCallback(TrainerCallback):
92
+ """
93
+ A callback that runs the semantic search evaluation at the end of each log step.
94
+ The search function is passed in during initialization.
95
+ """
96
+ def __init__(self, search_fn: Callable[[], str]):
97
+ self.search_fn = search_fn
98
+
99
+ def on_log(self, args: TrainingArguments, state, control, **kwargs):
100
+ print(f"Step {state.global_step} finished. Running evaluation:")
101
+ print(f"\n{self.search_fn()}\n")
102
+
103
+
104
+ def train_with_dataset(
105
+ model: SentenceTransformer,
106
+ dataset: List[List[str]],
107
+ output_dir: Path,
108
+ task_name: str,
109
+ search_fn: Callable[[], str]
110
+ ) -> None:
111
+ """
112
+ Fine-tunes the provided Sentence Transformer MODEL on the dataset.
113
+
114
+ The dataset should be a list of lists: [[anchor, positive, negative], ...].
115
+ """
116
+ # Convert to Hugging Face Dataset format
117
+ data_as_dicts = [
118
+ {"anchor": row[0], "positive": row[1], "negative": row[2]}
119
+ for row in dataset
120
+ ]
121
+
122
+ train_dataset = Dataset.from_list(data_as_dicts)
123
+
124
+ # Use MultipleNegativesRankingLoss, suitable for contrastive learning
125
+ loss = MultipleNegativesRankingLoss(model)
126
+
127
+ # Note: SentenceTransformer models typically have a 'prompts' attribute
128
+ # which we need to access for the training arguments.
129
+ prompts = getattr(model, 'prompts', {}).get(task_name)
130
+ if not prompts:
131
+ print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
132
+ # Fallback to an empty list or appropriate default if required by the model's structure
133
+ prompts = []
134
+
135
+ args = SentenceTransformerTrainingArguments(
136
+ output_dir=output_dir,
137
+ prompts=prompts,
138
+ num_train_epochs=4,
139
+ per_device_train_batch_size=1,
140
+ learning_rate=2e-5,
141
+ warmup_ratio=0.1,
142
+ logging_steps=train_dataset.num_rows,
143
+ report_to="none",
144
+ save_strategy="no" # No saving during training, only at the end
145
+ )
146
+
147
+ trainer = SentenceTransformerTrainer(
148
+ model=model,
149
+ args=args,
150
+ train_dataset=train_dataset,
151
+ loss=loss,
152
+ callbacks=[EvaluationCallback(search_fn)]
153
+ )
154
+
155
+ trainer.train()
156
+
157
+ print("Training finished. Model weights are updated in memory.")
158
+
159
+ # Save the final fine-tuned model
160
+ trainer.save_model()
161
+
162
+ print(f"Model saved locally to: {output_dir}")
src/vibe_logic.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from math import floor
3
+ from typing import List
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+ # --- Data Structures ---
7
+
8
+ @dataclass(frozen=True)
9
+ class VibeThreshold:
10
+ """Defines a threshold for a Vibe status."""
11
+ score: float
12
+ status: str
13
+
14
+ @dataclass(frozen=True)
15
+ class VibeResult:
16
+ """Stores the calculated HSL color and status for a given score."""
17
+ raw_score: float
18
+ status_html: str # Pre-formatted HTML for display
19
+ color_hsl: str # Raw HSL color string
20
+
21
+ # Define the status thresholds from highest score to lowest score
22
+ VIBE_THRESHOLDS: List[VibeThreshold] = [
23
+ VibeThreshold(score=0.8, status="✨ VIBE:HIGH"),
24
+ VibeThreshold(score=0.5, status="πŸ‘ VIBE:GOOD"),
25
+ VibeThreshold(score=0.2, status="😐 VIBE:FLAT"),
26
+ VibeThreshold(score=0.0, status="πŸ‘Ž VIBE:LOW&nbsp;"), # Base case for scores < 0.2
27
+ ]
28
+
29
+ # --- Utility Functions ---
30
+
31
+ def map_score_to_vibe(score: float) -> VibeResult:
32
+ """
33
+ Maps a cosine similarity score to a VibeResult containing status, HTML, and color.
34
+ """
35
+ # 1. Clamp score for safety
36
+ clamped_score = max(0.0, min(1.0, score))
37
+
38
+ # 2. Color Calculation
39
+ hue = floor(clamped_score * 120) # Linear interpolation: 0 (Red) -> 120 (Green)
40
+ color_hsl = f"hsl({hue}, 80%, 50%)"
41
+
42
+ # 3. Status Determination
43
+ status_text: str = VIBE_THRESHOLDS[-1].status # Default to the lowest status
44
+ for threshold in VIBE_THRESHOLDS:
45
+ if clamped_score >= threshold.score:
46
+ status_text = threshold.status
47
+ break
48
+
49
+ # 4. Create the pre-formatted HTML for display
50
+ status_html = f"<span style='color: {color_hsl}; font-weight: bold;'>{status_text}</span>"
51
+
52
+ return VibeResult(raw_score=score, status_html=status_html, color_hsl=color_hsl)
53
+
54
+
55
+ # --- Core Logic Class ---
56
+
57
+ class VibeChecker:
58
+ """
59
+ Handles similarity scoring using a SentenceTransformer model and a pre-set anchor query.
60
+ """
61
+ def __init__(self, model: SentenceTransformer, query_anchor: str, task_name: str):
62
+ self.model = model
63
+ self.query_anchor = query_anchor
64
+ self.task_name = task_name
65
+
66
+ # Pre-calculate the anchor embedding for efficiency
67
+ self.query_embedding = self.model.encode(
68
+ self.query_anchor,
69
+ prompt_name=self.task_name,
70
+ normalize_embeddings=True
71
+ )
72
+
73
+ def check(self, text: str) -> VibeResult:
74
+ """
75
+ Calculates the "vibe" of a given text against the pre-configured anchor.
76
+ """
77
+ title_embedding = self.model.encode(
78
+ text,
79
+ prompt_name=self.task_name,
80
+ normalize_embeddings=True
81
+ )
82
+ # Use dot product for similarity with normalized embeddings
83
+ score: float = util.dot_score(self.query_embedding, title_embedding).item()
84
+
85
+ return map_score_to_vibe(score)