# SVR_Predict_Stocks / src/inference.py
# Last changed by Reality8081 in commit 35beba6 ("Update src").
import os
import joblib
import pandas as pd
from datetime import datetime, timedelta
from huggingface_hub import hf_hub_download
from src.data_processing import load_data, clean_data, generate_technical_features
# Hugging Face Hub repository that hosts the trained scaler/model .pkl files.
# NOTE: change this to your own repository if you publish the models elsewhere.
REPO_ID: str = "Reality8081/Predict_Stock_SVR_Linear"
# Market symbol passed to load_data alongside the requested ticker
# (presumably the S&P 500 index in Yahoo-style notation - confirm).
MARKET_SYMBOL: str = "^GSPC"
# Forecast horizons in days; each one selects the artifacts named *_{h}d.pkl.
HORIZONS: list = [1, 7, 21]
def download_model_if_not_exists(filename):
    """Return a local path for ``filename``, downloading it from the Hub if missing.

    The file is looked up under the ``models`` directory, relative to the current
    working directory. If it is already there, that path is returned without any
    network access. Otherwise the directory is created and the file is fetched
    from ``REPO_ID`` with ``hf_hub_download``, and the path it reports is returned.

    Args:
        filename: Name of the artifact inside the repository, e.g. ``model_lr_1d.pkl``.

    Returns:
        Filesystem path of the (possibly freshly downloaded) file.
    """
    cached_path = os.path.join("models", filename)
    if os.path.exists(cached_path):
        return cached_path
    # Cache miss: make sure the target directory exists, then pull from the Hub.
    os.makedirs("models", exist_ok=True)
    return hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir="models")
# User-facing model choice -> artifact filename suffix (dict order = output order).
_MODEL_SUFFIXES = {"Linear Regression": "lr", "SVR": "svr"}
# Choice that runs every model listed in _MODEL_SUFFIXES.
_BOTH_MODELS = "Cả Hai"


def _predict_return(suffix, horizon, features):
    """Predict one return with the scaler/model pair for ``suffix`` at ``horizon`` days.

    Loads ``scaler_{suffix}_{horizon}d.pkl`` and ``model_{suffix}_{horizon}d.pkl``
    (downloading them if needed), scales ``features`` and returns the first
    element of the model's prediction.
    """
    # SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load artifacts from a repository you trust (see REPO_ID).
    scaler = joblib.load(download_model_if_not_exists(f'scaler_{suffix}_{horizon}d.pkl'))
    model = joblib.load(download_model_if_not_exists(f'model_{suffix}_{horizon}d.pkl'))
    return model.predict(scaler.transform(features))[0]


def predict_horizons(ticker, model_name):
    """Predict returns and prices for ``ticker`` at every horizon in ``HORIZONS``.

    Loads the last 150 calendar days of data for ``ticker`` (plus
    ``MARKET_SYMBOL``), cleans it, builds the technical feature matrix, and feeds
    the most recent feature row to the selected model(s).

    Args:
        ticker: Symbol passed to ``load_data``.
        model_name: ``"Linear Regression"``, ``"SVR"``, or ``"Cả Hai"`` to run both.

    Returns:
        A 4-tuple ``(predictions, last_close, last_date, historical_30)``:

        - ``predictions``: ``{horizon: {model_label: {"pred_return": r,
          "pred_price": last_close * (1 + r)}}}``, with one key per entry of
          ``HORIZONS``.
        - ``last_close``: ``Close`` of the latest feature row.
        - ``last_date``: that row's ``Date`` formatted as ``YYYY-MM-DD``.
        - ``historical_30``: the last 30 rows of the ``Date``/``Close`` columns.

    Raises:
        ValueError: if ``model_name`` is not a supported choice, or if feature
            generation yields no rows for ``ticker``.
    """
    # Validate the choice up front so a typo fails loudly instead of silently
    # returning empty prediction dicts.
    if model_name == _BOTH_MODELS:
        selected = list(_MODEL_SUFFIXES)
    elif model_name in _MODEL_SUFFIXES:
        selected = [model_name]
    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected "
            f"'Linear Regression', 'SVR' or '{_BOTH_MODELS}'."
        )

    end_date = datetime.now()
    start_date = end_date - timedelta(days=150)
    df_raw = load_data(
        [ticker],
        MARKET_SYMBOL,
        start_date.strftime('%Y-%m-%d'),
        end_date.strftime('%Y-%m-%d'),
    )
    df_clean = clean_data(df_raw)
    df_features, X, _ = generate_technical_features(df_clean, is_inference=True)
    if len(X) == 0:
        raise ValueError(f"Không đủ dữ liệu cho {ticker}.")

    # The list indexer keeps a single-row 2-D slice, which the (presumably
    # scikit-learn-style) scaler/model expect.
    latest_X = X.iloc[[-1]]
    latest_data = df_features.iloc[-1]
    last_close = latest_data['Close']
    last_date = latest_data['Date'].strftime('%Y-%m-%d')

    # Derive the horizon keys from HORIZONS so the two can never drift apart.
    predictions = {h: {} for h in HORIZONS}
    for h in HORIZONS:
        for label in selected:
            pred_return = _predict_return(_MODEL_SUFFIXES[label], h, latest_X)
            predictions[h][label] = {
                "pred_return": pred_return,
                "pred_price": last_close * (1 + pred_return),
            }

    historical_30 = df_features[['Date', 'Close']].tail(30)
    return predictions, last_close, last_date, historical_30