Ken Sang Tang commited on
Commit
47d5964
·
verified ·
1 Parent(s): b230721

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -15,14 +15,16 @@ MODEL_NAME = "databricks/dolly-v2-3b"
15
  FINBERT_MODEL_NAME = "yiyanghkust/finbert-tone"
16
  SYMBOL = '^KLSE' # KLCI index symbol
17
  START_DATE = '2020-01-01'
18
- END_DATE = '2023-12-31'
19
 
20
  # Initialize Alpaca API
21
  api = REST(ALPACA_API_KEY, ALPACA_SECRET_KEY, ALPACA_BASE_URL)
22
 
23
  # Load Models
24
  dolly_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
- dolly_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
 
26
  finbert_tokenizer = AutoTokenizer.from_pretrained(FINBERT_MODEL_NAME)
27
  finbert_model = AutoModelForSequenceClassification.from_pretrained(FINBERT_MODEL_NAME)
28
 
@@ -59,7 +61,7 @@ def generate_prediction(prompt):
59
  return dolly_tokenizer.decode(outputs[0], skip_special_tokens=True)
60
 
61
  # Step 5: Execute Trade with Alpaca
62
- def execute_trade(signal, symbol='AAPL', qty=1):
63
  print(f"Executing trade signal: {signal}")
64
  if signal == "buy":
65
  api.submit_order(symbol=symbol, qty=qty, side='buy', type='market', time_in_force='gtc')
 
15
  FINBERT_MODEL_NAME = "yiyanghkust/finbert-tone"
16
  SYMBOL = '^KLSE' # KLCI index symbol
17
  START_DATE = '2020-01-01'
18
+ END_DATE = '2024-07-31'
19
 
20
  # Initialize Alpaca API
21
  api = REST(ALPACA_API_KEY, ALPACA_SECRET_KEY, ALPACA_BASE_URL)
22
 
23
  # Load Models
24
  dolly_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
+ dolly_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype.float16) # half precision
26
+ dolly_model.gradient_checkpointing_enable() # memory-efficient loading
27
+ print("Loading FinBERT model...")
28
  finbert_tokenizer = AutoTokenizer.from_pretrained(FINBERT_MODEL_NAME)
29
  finbert_model = AutoModelForSequenceClassification.from_pretrained(FINBERT_MODEL_NAME)
30
 
 
61
  return dolly_tokenizer.decode(outputs[0], skip_special_tokens=True)
62
 
63
  # Step 5: Execute Trade with Alpaca
64
+ def execute_trade(signal, symbol='Genting Bhd', qty=1):
65
  print(f"Executing trade signal: {signal}")
66
  if signal == "buy":
67
  api.submit_order(symbol=symbol, qty=qty, side='buy', type='market', time_in_force='gtc')