| import os |
| import tensorflow as tf |
| import transformers |
| from tensorflow import keras |
| from transformers import BertTokenizer, TFBertModel |
| import pandas as pd |
| from datetime import date, timedelta |
| import requests |
| import time |
| from typing import List, Optional, Dict, Any |
|
|
| from transformers import AutoTokenizer |
| |
# Hugging Face Hub ID passed to AutoTokenizer.from_pretrained() as the
# default tokenizer for BertPredictor (hfl/rbt3 — presumably a small
# Chinese RoBERTa checkpoint; confirm against the Hub model card).
Bert_model_name = "hfl/rbt3"
|
|
class BertPredictor:
    """Fetch market news, score each article with a fine-tuned BERT model,
    and expose aggregate statistics over the resulting daily CSV file.

    On construction the predictor loads the tokenizer and the Keras model,
    then ensures a per-day news CSV exists (fetching and scoring articles
    via the GNews API when it does not). Results are cached on disk as
    ``news_YYYY-MM-DD.csv`` next to this file, so the network is hit at
    most once per calendar day.
    """

    def __init__(self, tokenizer_name: str = Bert_model_name, max_news_per_keyword: int = 5):
        """Initialize the predictor: load tokenizer and model, fetch news if needed.

        Args:
            tokenizer_name (str): Hugging Face name of the BERT tokenizer.
            max_news_per_keyword (int): Maximum number of articles to fetch
                per search keyword.
        """
        # Resolve all paths relative to this file so the class works
        # regardless of the current working directory.
        self.current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5')

        # One CSV per calendar day doubles as a cache: when today's file
        # already exists, news fetching is skipped entirely.
        # BUGFIX: removed a leftover debug line that overrode this path with
        # the hard-coded 'news_2025-09-12.csv', defeating the daily caching.
        today_date_str = date.today().strftime('%Y-%m-%d')
        self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv')

        # Articles are searched for yesterday — the most recent fully
        # published day of news.
        self.target_date = date.today() - timedelta(days=1)
        self.target_date_str = self.target_date.strftime('%Y-%m-%d')

        # GNews API configuration. Prefer the GNEWS_API_KEY environment
        # variable; the in-source literal remains only as a backward-
        # compatible fallback. SECURITY: rotate this key and drop the
        # fallback once the environment variable is provisioned.
        self.api_key = os.environ.get("GNEWS_API_KEY", "fd12e84a158c7d9eaf31627aaae0927a")
        self.base_url = "https://gnews.io/api/v4/search"
        self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"]
        self.max_news_per_keyword = max_news_per_keyword
        # Per-request timeout (seconds) so a stalled API call cannot hang __init__.
        self.request_timeout = 10

        # Tokenization settings for the BERT encoder.
        self.text_max_length = 256
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        print("正在加載模型...")
        # custom_objects lets Keras deserialize the TFBertModel layer that
        # was saved inside the .h5 file.
        self.model = keras.models.load_model(
            self.model_path,
            custom_objects={'TFBertModel': TFBertModel}
        )
        print("模型加載完成。")

        self._check_file_and_get_news_if_needed()

    def _encode_texts(self, texts: List[str]):
        """Tokenize *texts* into BERT inputs (input_ids, attention_mask) as TF tensors."""
        return self.tokenizer(
            texts,
            max_length=self.text_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='tf'
        )

    def _predict(self, new_text: str) -> float:
        """Score a single news text with the loaded model.

        Args:
            new_text (str): News text to score.

        Returns:
            float: Predicted market-impact score as a plain Python float.
        """
        new_encoding = self._encode_texts([new_text])
        # model.predict returns shape (1, 1); unwrap to a scalar.
        predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0]
        return float(predicted_score)

    def _check_file_and_get_news_if_needed(self) -> None:
        """Call _get_news() only when today's news CSV does not exist yet."""
        if not os.path.exists(self.news_csv_path):
            print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。")
            self._get_news()
        else:
            print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。")

    def _get_news(self) -> None:
        """Fetch yesterday's news from the GNews API, score each article
        immediately, and write the results to today's CSV.

        A request failure for one keyword is logged and skipped so the
        remaining keywords are still processed; a short sleep between
        requests stays within the API rate limit.
        """
        print("開始執行新聞抓取與即時預測...")
        print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)")

        results = []
        for kw in self.keywords:
            params = {
                "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword,
                "in": "title,description", "apikey": self.api_key,
                "from": f"{self.target_date_str}T00:00:00Z",
                "to": f"{self.target_date_str}T23:59:59Z"
            }
            try:
                # BUGFIX: added a timeout so a hung request cannot block forever.
                response = requests.get(self.base_url, params=params, timeout=self.request_timeout)
                response.raise_for_status()
                data = response.json()
                print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞")
                for article in data.get("articles", []):
                    published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d')
                    # 'description' may be absent; fall back to an empty string.
                    news_content = f"{article['title']} - {article.get('description', '')}"
                    score = self._predict(news_content)
                    results.append({
                        "時間": published_date,
                        "分數": score,
                        "內容": news_content
                    })
            except requests.exceptions.RequestException as e:
                print(f"錯誤:API 請求失敗 - {e}")
                continue
            finally:
                # Rate-limit courtesy pause between keyword queries.
                time.sleep(0.5)

        if not results:
            print("抓取完成。未找到任何相關新聞。")
            # Write an empty file with the expected schema so downstream
            # readers (get_news_index / get_news) still behave consistently.
            df_to_save = pd.DataFrame(columns=['時間', '分數', '內容'])
        else:
            print(f"成功抓取並預測 {len(results)} 筆新聞。")
            df_to_save = pd.DataFrame(results)

        try:
            print(f"正在將結果寫入檔案 '{self.news_csv_path}'...")
            # utf-8-sig so spreadsheet apps detect the BOM and decode the
            # Chinese column headers correctly.
            df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig')
            print(f"成功!檔案已儲存至 '{self.news_csv_path}'。")
        except IOError as e:
            print(f"錯誤:寫入檔案失敗 - {e}")

    def get_news_index(self) -> Optional[float]:
        """Return the mean of all news scores in today's CSV.

        Returns:
            Optional[float]: Average score, or None when the file is missing,
            empty, lacks the score column, or contains no numeric scores.
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            if df.empty or '分數' not in df.columns:
                print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。")
                return None

            # errors='coerce' turns malformed cells into NaN, which mean() skips.
            average_score = pd.to_numeric(df['分數'], errors='coerce').mean()
            # Convert the numpy scalar to a plain float to honor the annotation.
            return float(average_score) if pd.notna(average_score) else None

        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或計算 CSV 檔案時發生錯誤:{e}")
            return None

    def get_news(self) -> Optional[List[str]]:
        """Return the contents of the three news items with the highest
        absolute scores from today's CSV.

        Returns:
            Optional[List[str]]: Up to three news strings (fewer when the
            file holds fewer rows), an empty list when no row carries a
            numeric score, or None on read/processing errors.
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            # A missing '分數' column raises KeyError here, which the broad
            # handler below converts into a None return.
            df['分數'] = pd.to_numeric(df['分數'], errors='coerce')
            df.dropna(subset=['分數'], inplace=True)
            if df.empty:
                return []

            # Rank by |score| so strongly negative news counts as impactful.
            df['abs_score'] = df['分數'].abs()
            top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3)

            return top_3_news_df['內容'].tolist()

        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或處理 CSV 檔案時發生錯誤:{e}")
            return None
|
|
| |
if __name__ == "__main__":
    # Demo entry point: require the trained model file, then print the
    # aggregate score and the three most impactful news items.
    model_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Best-complete-model.h5')
    if not os.path.exists(model_file):
        print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。")
    else:
        predictor = BertPredictor(max_news_per_keyword=3)

        print("\n" + "="*30)
        avg_score = predictor.get_news_index()
        if avg_score is None:
            print("無法計算新聞檔案中的平均分數。")
        else:
            print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}")

        print("\n" + "="*30)
        top_news_content = predictor.get_news()
        if top_news_content:
            print("\n分數絕對值最高的三則新聞內容:")
            for i, content in enumerate(top_news_content):
                print(f" {i+1}. {content}")
        elif top_news_content == []:
            print("新聞檔案中無有效內容可顯示。")
        else:
            print("無法獲取最高分新聞。")