| import pandas as pd |
| import numpy as np |
| import logging |
| import os |
| from typing import Dict, Any, Optional |
| from datetime import datetime, timedelta |
|
|
| logger = logging.getLogger(__name__) |
|
|
def load_data(config: Dict[str, Any]) -> Optional[pd.DataFrame]:
    """
    Load market data based on configuration.

    Dispatches to the loader matching ``config['data_source']['type']``
    ('alpaca', 'csv', or 'synthetic').

    Args:
        config: Configuration dictionary

    Returns:
        DataFrame with market data or None if error
    """
    try:
        source_type = config['data_source']['type']
        logger.info(f"Loading data from source: {source_type}")

        # Table-driven dispatch: one loader per supported source type.
        loaders = {
            'alpaca': _load_alpaca_data,
            'csv': _load_csv_data,
            'synthetic': _load_synthetic_data,
        }
        loader = loaders.get(source_type)
        if loader is None:
            logger.error(f"Unsupported data source: {source_type}")
            return None
        return loader(config)

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        return None
|
|
| def _load_alpaca_data(config: Dict[str, Any]) -> Optional[pd.DataFrame]: |
| """Load market data from Alpaca""" |
| try: |
| from .alpaca_broker import AlpacaBroker |
| |
| |
| alpaca_broker = AlpacaBroker(config) |
| |
| |
| symbol = config['trading']['symbol'] |
| timeframe = config['trading']['timeframe'] |
| |
| |
| tf_map = { |
| '1m': '1Min', |
| '5m': '5Min', |
| '15m': '15Min', |
| '1h': '1Hour', |
| '1d': '1Day' |
| } |
| alpaca_timeframe = tf_map.get(timeframe, '1Min') |
| |
| |
| data = alpaca_broker.get_market_data( |
| symbol=symbol, |
| timeframe=alpaca_timeframe, |
| limit=1000 |
| ) |
| |
| if data is not None and not data.empty: |
| logger.info(f"Loaded {len(data)} data points from Alpaca for {symbol}") |
| return data |
| else: |
| logger.error(f"No data returned from Alpaca for {symbol}") |
| return None |
| |
| except Exception as e: |
| logger.error(f"Error loading Alpaca data: {e}") |
| return None |
|
|
| def _load_csv_data(config: Dict[str, Any]) -> Optional[pd.DataFrame]: |
| """Load market data from CSV file""" |
| try: |
| file_path = config['data_source']['path'] |
| |
| if not os.path.exists(file_path): |
| logger.error(f"CSV file not found: {file_path}") |
| return None |
| |
| |
| data = pd.read_csv(file_path) |
| |
| |
| if 'date' in data.columns and 'timestamp' not in data.columns: |
| data = data.rename(columns={'date': 'timestamp'}) |
| |
| |
| required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] |
| missing_columns = [col for col in required_columns if col not in data.columns] |
| |
| if missing_columns: |
| logger.error(f"Missing required columns: {missing_columns}") |
| return None |
| |
| |
| data['timestamp'] = pd.to_datetime(data['timestamp']) |
| |
| |
| data = data.sort_values('timestamp').reset_index(drop=True) |
| |
| logger.info(f"Loaded {len(data)} data points from CSV: {file_path}") |
| return data |
| |
| except Exception as e: |
| logger.error(f"Error loading CSV data: {e}") |
| return None |
|
|
| def _load_synthetic_data(config: Dict[str, Any]) -> Optional[pd.DataFrame]: |
| """Load or generate synthetic market data""" |
| try: |
| synthetic_config = config.get('synthetic_data', {}) |
| data_path = synthetic_config.get('data_path', 'data/synthetic_market_data.csv') |
| |
| |
| if os.path.exists(data_path): |
| logger.info(f"Loading existing synthetic data from: {data_path}") |
| return _load_csv_data({'data_source': {'path': data_path}}) |
| |
| |
| logger.info("Generating new synthetic market data") |
| from .synthetic_data_generator import SyntheticDataGenerator |
| |
| generator = SyntheticDataGenerator(config) |
| data = generator.generate_data() |
| |
| if data is not None and not data.empty: |
| |
| os.makedirs(os.path.dirname(data_path), exist_ok=True) |
| data.to_csv(data_path, index=False) |
| logger.info(f"Saved synthetic data to: {data_path}") |
| return data |
| else: |
| logger.error("Failed to generate synthetic data") |
| return None |
| |
| except Exception as e: |
| logger.error(f"Error loading synthetic data: {e}") |
| return None |
|
|
def validate_data(data: pd.DataFrame) -> bool:
    """
    Validate market data quality.

    Checks for: presence of OHLCV + timestamp columns ('date' accepted as
    an alias for 'timestamp'), NaNs (dropped), negative prices, mostly-zero
    volume, invalid OHLC relationships, and timestamp ordering.

    Side effect: rows containing NaN in the required columns are removed
    in place, so when no column rename was needed this mutates the
    caller's DataFrame.

    Args:
        data: DataFrame with market data

    Returns:
        True if data is valid, False otherwise
    """
    try:
        if data is None or data.empty:
            logger.error("Data is None or empty")
            return False

        # Accept 'date' as an alias for 'timestamp'. NOTE: rename returns a
        # new frame, so later in-place edits then affect the local copy only.
        if 'date' in data.columns and 'timestamp' not in data.columns:
            data = data.rename(columns={'date': 'timestamp'})

        required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        missing_columns = [col for col in required_columns if col not in data.columns]

        if missing_columns:
            logger.error(f"Missing required columns: {missing_columns}")
            return False

        # Drop rows with NaN in any required column (logged, not fatal).
        nan_counts = data[required_columns].isna().sum()
        if nan_counts.sum() > 0:
            logger.warning(f"Found NaN values: {nan_counts.to_dict()}")

            data.dropna(subset=required_columns, inplace=True)
            logger.info(f"Removed {nan_counts.sum()} rows with NaN values")

        # Negative prices are always invalid.
        price_columns = ['open', 'high', 'low', 'close']
        negative_prices = data[price_columns] < 0
        if negative_prices.any().any():
            logger.error("Found negative prices in data")
            return False

        # More than 50% zero-volume bars is suspicious but not fatal.
        zero_volumes = data['volume'] == 0
        if zero_volumes.sum() > len(data) * 0.5:
            logger.warning("High percentage of zero volumes detected")

        # Each bar must satisfy low <= open, close <= high.
        invalid_ohlc = (
            (data['high'] < data['low']) |
            (data['open'] > data['high']) |
            (data['close'] > data['high']) |
            (data['open'] < data['low']) |
            (data['close'] < data['low'])
        )

        if invalid_ohlc.any():
            logger.error("Found invalid OHLC relationships")
            return False

        # Out-of-order timestamps are only warned about. BUG FIX: the old
        # code also rebound a sorted copy to the local name and then never
        # used it — a dead assignment that misleadingly suggested the
        # caller's frame got sorted. Removed; validation never sorts.
        if 'timestamp' in data.columns:
            timestamps = pd.to_datetime(data['timestamp'])
            if not timestamps.is_monotonic_increasing:
                logger.warning("Timestamps are not in ascending order")

        logger.info(f"Data validation passed: {len(data)} valid records")
        return True

    except Exception as e:
        logger.error(f"Error validating data: {e}")
        return False
|
|
def add_technical_indicators(data: pd.DataFrame) -> pd.DataFrame:
    """
    Add technical indicators to market data.

    Computes SMAs (20/50/200), EMAs (12/26), MACD, RSI (14), Bollinger
    Bands (20, 2 std), ATR (14), volume ratio, and percent price changes.
    The input frame is not modified; a copy with extra columns is returned.

    Args:
        data: DataFrame with OHLCV data

    Returns:
        DataFrame with technical indicators added
    """
    try:
        out = data.copy()
        close = out['close']

        # Simple moving averages over common lookback windows.
        for window in (20, 50, 200):
            out[f'sma_{window}'] = close.rolling(window=window).mean()

        # Exponential moving averages and MACD family.
        out['ema_12'] = close.ewm(span=12).mean()
        out['ema_26'] = close.ewm(span=26).mean()
        out['macd'] = out['ema_12'] - out['ema_26']
        out['macd_signal'] = out['macd'].ewm(span=9).mean()
        out['macd_histogram'] = out['macd'] - out['macd_signal']

        # RSI using simple (rolling-mean) averages of gains and losses.
        delta = close.diff()
        avg_gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        avg_loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        out['rsi'] = 100 - (100 / (1 + avg_gain / avg_loss))

        # Bollinger Bands: 20-period mean +/- 2 standard deviations.
        out['bb_middle'] = close.rolling(window=20).mean()
        band_offset = close.rolling(window=20).std() * 2
        out['bb_upper'] = out['bb_middle'] + band_offset
        out['bb_lower'] = out['bb_middle'] - band_offset

        # Average True Range over 14 periods.
        prev_close = close.shift()
        true_range = np.maximum(
            out['high'] - out['low'],
            np.maximum((out['high'] - prev_close).abs(),
                       (out['low'] - prev_close).abs()),
        )
        out['atr'] = true_range.rolling(window=14).mean()

        # Volume relative to its 20-period average.
        out['volume_sma'] = out['volume'].rolling(window=20).mean()
        out['volume_ratio'] = out['volume'] / out['volume_sma']

        # Percent price changes over 1, 5, and 20 periods.
        out['price_change'] = close.pct_change()
        out['price_change_5'] = close.pct_change(periods=5)
        out['price_change_20'] = close.pct_change(periods=20)

        logger.info("Technical indicators added successfully")
        return out

    except Exception as e:
        logger.error(f"Error adding technical indicators: {e}")
        return data
|
|
def get_latest_data(data: pd.DataFrame, n_periods: int = 100) -> pd.DataFrame:
    """
    Get the latest n periods of data.

    When the frame has more than ``n_periods`` rows, returns the trailing
    rows with a fresh 0-based index; otherwise returns the frame unchanged
    (same object, no copy).

    Args:
        data: DataFrame with market data
        n_periods: Number of periods to return

    Returns:
        DataFrame with latest n periods
    """
    try:
        if len(data) > n_periods:
            return data.tail(n_periods).reset_index(drop=True)
        return data

    except Exception as e:
        logger.error(f"Error getting latest data: {e}")
        return data
|
|