hk-trading-platform / utils /cache_manager.py
Humphreykowl's picture
Upload 2 files
5351a1e verified
"""
缓存管理模块
Cache Management Module
"""
import pickle
import json
import hashlib
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
import threading
from functools import wraps
import pandas as pd
import numpy as np
class CacheBackend(ABC):
    """Abstract base class that every cache backend implements."""

    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        """Get the cached value for ``key``."""

    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key`` with an optional time-to-live."""

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Delete the entry for ``key``."""

    @abstractmethod
    def clear(self) -> bool:
        """Clear the cache."""

    @abstractmethod
    def exists(self, key: str) -> bool:
        """Check whether ``key`` is present in the cache."""

    @abstractmethod
    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
class MemoryCache(CacheBackend):
    """In-memory cache backend with per-entry TTL and LRU eviction.

    Values live in an ``OrderedDict`` ordered from least to most recently
    used; expiry timestamps are tracked in a parallel dict. Public methods
    hold a re-entrant lock while they touch that state.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 3600):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept. With a non-positive
                value nothing can be stored and ``set`` returns ``False``.
            default_ttl: Lifetime in seconds used when ``set`` is called
                without a TTL. A value <= 0 means the entry never expires.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self._cache = OrderedDict()
        self._expiry_times = {}
        self._lock = threading.RLock()
        self._stats = {
            'hits': 0,
            'misses': 0,
            'sets': 0,
            'deletes': 0,
            'evictions': 0
        }
        self.logger = logging.getLogger(__name__)

    def _is_expired(self, key: str) -> bool:
        """Return True when ``key`` has an expiry time that has passed."""
        expiry_time = self._expiry_times.get(key)
        return expiry_time is not None and time.time() > expiry_time

    def _evict_expired(self):
        """Remove every expired entry, counting each one as an eviction."""
        current_time = time.time()
        expired_keys = [
            key for key, expiry_time in self._expiry_times.items()
            if current_time > expiry_time
        ]
        for key in expired_keys:
            self._remove_key(key)
            self._stats['evictions'] += 1

    def _evict_lru(self):
        """Drop least recently used entries until one more entry fits."""
        # The emptiness check keeps a non-positive max_size from looping on
        # an empty mapping.
        while self._cache and len(self._cache) >= self.max_size:
            oldest_key, _ = self._cache.popitem(last=False)
            self._expiry_times.pop(oldest_key, None)
            self._stats['evictions'] += 1

    def _remove_key(self, key: str):
        """Remove ``key`` from both the value store and the expiry table."""
        self._cache.pop(key, None)
        self._expiry_times.pop(key, None)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss.

        A hit marks the entry as most recently used. Expired entries are
        purged first and count as misses.
        """
        with self._lock:
            self._evict_expired()
            if key not in self._cache:
                self._stats['misses'] += 1
                return None
            # The sweep above used an earlier timestamp; re-check this key.
            if self._is_expired(key):
                self._remove_key(key)
                self._stats['misses'] += 1
                return None
            # Move to the end (most recently used position).
            self._cache.move_to_end(key)
            self._stats['hits'] += 1
            return self._cache[key]

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Lifetime in seconds; ``None`` uses ``default_ttl`` and a
                value <= 0 means the entry never expires.

        Returns:
            True when the value was stored, False when it could not be
            (the error is logged).
        """
        try:
            # Compute the expiry before mutating anything so a bad ``ttl``
            # leaves the cache untouched.
            if ttl is None:
                ttl = self.default_ttl
            expiry_time = time.time() + ttl if ttl > 0 else float('inf')
            with self._lock:
                if self.max_size <= 0:
                    # A non-positive capacity can never hold an entry.
                    return False
                self._evict_expired()
                # Drop any previous value first: overwriting an existing key
                # must not push an unrelated entry out of a full cache.
                self._remove_key(key)
                self._evict_lru()
                self._cache[key] = value
                self._expiry_times[key] = expiry_time
                self._stats['sets'] += 1
            return True
        except Exception as e:
            self.logger.error(f"Failed to set cache key {key}: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True when an entry was actually removed."""
        with self._lock:
            if key in self._cache:
                self._remove_key(key)
                self._stats['deletes'] += 1
                return True
            return False

    def clear(self) -> bool:
        """Remove all entries. Statistics counters are left as they are."""
        with self._lock:
            self._cache.clear()
            self._expiry_times.clear()
            return True

    def exists(self, key: str) -> bool:
        """Return True when ``key`` is cached and not expired.

        An expired entry found here is removed. Hit/miss counters are not
        updated.
        """
        with self._lock:
            if key not in self._cache:
                return False
            if self._is_expired(key):
                self._remove_key(key)
                return False
            return True

    def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus size, capacity, hit rate and memory estimate."""
        with self._lock:
            self._evict_expired()
            total_requests = self._stats['hits'] + self._stats['misses']
            hit_rate = self._stats['hits'] / total_requests if total_requests > 0 else 0
            return {
                **self._stats,
                'size': len(self._cache),
                'max_size': self.max_size,
                'hit_rate': hit_rate,
                'memory_usage': self._estimate_memory_usage()
            }

    def _estimate_memory_usage(self) -> int:
        """Estimate memory use as the summed pickled size of keys and values.

        Returns 0 when any entry cannot be pickled.
        """
        try:
            return sum(
                len(pickle.dumps(key)) + len(pickle.dumps(value))
                for key, value in self._cache.items()
            )
        except Exception:
            return 0
class FileCache(CacheBackend):
    """File-system cache backend.

    Each entry is stored as two files named after the MD5 hex digest of its
    key: ``<digest>.cache`` holds the pickled value and ``<digest>.meta``
    holds JSON metadata (key, creation time, expiry time, TTL).
    """

    def __init__(self, cache_dir: str = "./cache", default_ttl: int = 3600):
        """Create the backend, creating ``cache_dir`` if it does not exist.

        Args:
            cache_dir: Directory that holds the cache files.
            default_ttl: Lifetime in seconds used when ``set`` is called
                without a TTL. A value <= 0 means the entry never expires.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.default_ttl = default_ttl
        self.logger = logging.getLogger(__name__)
        self._stats = {
            'hits': 0,
            'misses': 0,
            'sets': 0,
            'deletes': 0,
            'evictions': 0
        }

    @staticmethod
    def _hash_key(key: str) -> str:
        """Return the MD5 hex digest of ``key`` (used only to build file names)."""
        return hashlib.md5(key.encode()).hexdigest()

    @staticmethod
    def _unlink_if_present(path: Path) -> bool:
        """Delete ``path``; return False instead of raising when it is already gone."""
        try:
            path.unlink()
        except FileNotFoundError:
            return False
        return True

    def _get_file_path(self, key: str) -> Path:
        """Return the path of the data file for ``key``."""
        # Hash the key so arbitrary key text never ends up in a file name.
        return self.cache_dir / f"{self._hash_key(key)}.cache"

    def _get_meta_path(self, key: str) -> Path:
        """Return the path of the metadata file for ``key``."""
        return self.cache_dir / f"{self._hash_key(key)}.meta"

    def _is_expired(self, key: str) -> bool:
        """Return True when ``key`` is expired.

        Missing or unreadable metadata is treated as expired.
        """
        meta_path = self._get_meta_path(key)
        if not meta_path.exists():
            return True
        try:
            with open(meta_path, 'r') as f:
                meta = json.load(f)
            expiry_time = meta.get('expiry_time', 0)
            return time.time() > expiry_time
        except Exception:
            return True

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss.

        A missing, expired or unreadable entry counts as a miss.
        """
        file_path = self._get_file_path(key)
        if not file_path.exists() or self._is_expired(key):
            self._stats['misses'] += 1
            return None
        try:
            # NOTE(security): pickle.load can execute arbitrary code. The
            # cache directory must only be writable by trusted processes.
            with open(file_path, 'rb') as f:
                value = pickle.load(f)
            self._stats['hits'] += 1
            return value
        except Exception as e:
            self.logger.error(f"Failed to load cache file {file_path}: {e}")
            self._stats['misses'] += 1
            return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Pickle ``value`` to disk and write its metadata.

        Args:
            key: Cache key.
            value: Picklable value to store.
            ttl: Lifetime in seconds; ``None`` uses ``default_ttl`` and a
                value <= 0 means the entry never expires.

        Returns:
            True on success, False when writing failed (the error is logged).
        """
        try:
            file_path = self._get_file_path(key)
            meta_path = self._get_meta_path(key)
            # Write the data file.
            with open(file_path, 'wb') as f:
                pickle.dump(value, f)
            # Write the metadata file.
            if ttl is None:
                ttl = self.default_ttl
            expiry_time = time.time() + ttl if ttl > 0 else float('inf')
            meta_data = {
                'key': key,
                'created_time': time.time(),
                'expiry_time': expiry_time,
                'ttl': ttl
            }
            with open(meta_path, 'w') as f:
                json.dump(meta_data, f)
            self._stats['sets'] += 1
            return True
        except Exception as e:
            self.logger.error(f"Failed to save cache file for key {key}: {e}")
            return False

    def delete(self, key: str) -> bool:
        """Remove the files for ``key``; return True when any file was removed."""
        # Evaluate both removals before combining so neither is skipped.
        removed_data = self._unlink_if_present(self._get_file_path(key))
        removed_meta = self._unlink_if_present(self._get_meta_path(key))
        deleted = removed_data or removed_meta
        if deleted:
            self._stats['deletes'] += 1
        return deleted

    def clear(self) -> bool:
        """Remove every cache and metadata file in the cache directory.

        Returns:
            True on success, False when a removal failed (the error is logged).
        """
        try:
            for cache_file in self.cache_dir.glob("*.cache"):
                cache_file.unlink()
            for meta_file in self.cache_dir.glob("*.meta"):
                meta_file.unlink()
            return True
        except Exception as e:
            self.logger.error(f"Failed to clear cache: {e}")
            return False

    def exists(self, key: str) -> bool:
        """Return True when a data file exists for ``key`` and it is not expired."""
        file_path = self._get_file_path(key)
        return file_path.exists() and not self._is_expired(key)

    def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus entry count, disk usage, hit rate and directory."""
        cache_files = list(self.cache_dir.glob("*.cache"))
        total_size = sum(f.stat().st_size for f in cache_files)
        total_requests = self._stats['hits'] + self._stats['misses']
        hit_rate = self._stats['hits'] / total_requests if total_requests > 0 else 0
        return {
            **self._stats,
            'size': len(cache_files),
            'disk_usage': total_size,
            'hit_rate': hit_rate,
            'cache_dir': str(self.cache_dir)
        }

    def cleanup_expired(self) -> int:
        """Delete every entry whose metadata says it has expired.

        Returns:
            The number of expired entries removed. Metadata files that
            cannot be read are logged and skipped.
        """
        expired_count = 0
        # Snapshot the listing because files are deleted while iterating.
        for meta_file in list(self.cache_dir.glob("*.meta")):
            try:
                with open(meta_file, 'r') as f:
                    meta = json.load(f)
                if not time.time() > meta.get('expiry_time', 0):
                    continue
                # Locate the data file from the metadata file's own name
                # rather than re-hashing meta['key'], which may be missing.
                removed_data = self._unlink_if_present(meta_file.with_suffix(".cache"))
                removed_meta = self._unlink_if_present(meta_file)
                if removed_data or removed_meta:
                    expired_count += 1
                    # An expiry removal is counted as a delete and an eviction.
                    self._stats['deletes'] += 1
                    self._stats['evictions'] += 1
            except Exception as e:
                self.logger.warning(f"Failed to check expiry for {meta_file}: {e}")
        return expired_count
class CacheManager:
    """Front end over a cache backend that prefixes keys and adds pandas/NumPy helpers."""

    def __init__(self, backend: str = 'memory', **backend_kwargs):
        """Create the manager and its backend.

        Args:
            backend: ``'memory'`` for ``MemoryCache`` or ``'file'`` for
                ``FileCache``.
            **backend_kwargs: Passed through to the backend constructor.

        Raises:
            ValueError: If ``backend`` is not a supported name.
        """
        self.logger = logging.getLogger(__name__)
        # Initialise the backend.
        if backend == 'memory':
            self.backend = MemoryCache(**backend_kwargs)
        elif backend == 'file':
            self.backend = FileCache(**backend_kwargs)
        else:
            raise ValueError(f"Unsupported cache backend: {backend}")
        self.backend_type = backend
        self._key_prefix = "hk_trading_"

    def _make_key(self, key: str) -> str:
        """Return ``key`` with the manager's prefix prepended."""
        return f"{self._key_prefix}{key}"

    def get(self, key: str, default: Any = None) -> Any:
        """Return the cached value for ``key``, or ``default`` when the backend returns ``None``."""
        value = self.backend.get(self._make_key(key))
        return value if value is not None else default

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Store ``value`` under ``key``; return the backend's success flag."""
        return self.backend.set(self._make_key(key), value, ttl)

    def delete(self, key: str) -> bool:
        """Delete ``key``; return the backend's result."""
        return self.backend.delete(self._make_key(key))

    def exists(self, key: str) -> bool:
        """Return whether the backend reports ``key`` as present."""
        return self.backend.exists(self._make_key(key))

    def clear(self) -> bool:
        """Clear the whole backend (not only keys carrying this manager's prefix)."""
        return self.backend.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return the backend statistics with ``backend_type`` added."""
        stats = self.backend.get_stats()
        stats['backend_type'] = self.backend_type
        return stats

    def cache_dataframe(self, key: str, df: pd.DataFrame, ttl: Optional[int] = None) -> bool:
        """Cache a DataFrame as a dict of records, index, columns and dtypes.

        Returns:
            True when stored, False when conversion or storage failed
            (the error is logged).
        """
        try:
            # Convert to plain containers so the payload serialises easily.
            df_dict = {
                'data': df.to_dict('records'),
                'index': df.index.tolist(),
                'columns': df.columns.tolist(),
                'dtypes': df.dtypes.to_dict()
            }
            return self.set(key, df_dict, ttl)
        except Exception as e:
            self.logger.error(f"Failed to cache DataFrame: {e}")
            return False

    def get_dataframe(self, key: str) -> Optional[pd.DataFrame]:
        """Rebuild a DataFrame stored by ``cache_dataframe``.

        Returns:
            The DataFrame, or ``None`` when the key is missing or the
            payload cannot be turned back into a frame (the error is logged).
        """
        try:
            df_dict = self.get(key)
            if df_dict is None:
                return None
            records = df_dict['data']
            if records:
                df = pd.DataFrame(records)
            else:
                # A zero-row frame serialises to an empty record list, which
                # carries no column names; restore them from the saved list.
                df = pd.DataFrame(columns=df_dict.get('columns'))
            if 'index' in df_dict:
                df.index = df_dict['index']
            # Restore the original column dtypes where possible.
            if 'dtypes' in df_dict:
                for col, dtype in df_dict['dtypes'].items():
                    if col not in df.columns:
                        continue
                    try:
                        df[col] = df[col].astype(dtype)
                    except Exception as e:
                        # Best effort: keep the inferred dtype for this column.
                        self.logger.debug(f"Could not restore dtype for column {col}: {e}")
            return df
        except Exception as e:
            self.logger.error(f"Failed to load cached DataFrame: {e}")
            return None

    def cache_array(self, key: str, array: np.ndarray, ttl: Optional[int] = None) -> bool:
        """Cache a NumPy array as nested lists plus its shape and dtype name.

        Returns:
            True when stored, False when conversion or storage failed
            (the error is logged).
        """
        try:
            array_dict = {
                'data': array.tolist(),
                'shape': array.shape,
                'dtype': str(array.dtype)
            }
            return self.set(key, array_dict, ttl)
        except Exception as e:
            self.logger.error(f"Failed to cache array: {e}")
            return False

    def get_array(self, key: str) -> Optional[np.ndarray]:
        """Rebuild an array stored by ``cache_array``.

        Returns:
            The array with its saved dtype and shape, or ``None`` when the
            key is missing or reconstruction failed (the error is logged).
        """
        try:
            array_dict = self.get(key)
            if array_dict is None:
                return None
            array = np.array(array_dict['data'], dtype=array_dict['dtype'])
            return array.reshape(array_dict['shape'])
        except Exception as e:
            self.logger.error(f"Failed to load cached array: {e}")
            return None
# Cache decorator
def cached(ttl: int = 3600, key_func: Optional[callable] = None, cache_manager: Optional["CacheManager"] = None):
    """Return a decorator that caches a function's results in a cache manager.

    Args:
        ttl: Lifetime in seconds passed to the cache manager's ``set``.
        key_func: Optional callable that receives the call's arguments and
            returns the cache key. When omitted, the key is an MD5 digest of
            the function's module, qualified name and the ``repr`` of every
            argument.
        cache_manager: Manager to use; when ``None`` the default manager is
            looked up on every call.

    A result of ``None`` is indistinguishable from a cache miss, so
    functions returning ``None`` are executed again on every call.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Resolve the cache manager at call time.
            cm = cache_manager if cache_manager is not None else get_default_cache_manager()
            # Build the cache key.
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                # Default key: identify the function by module and qualified
                # name, and use repr() with a control-character separator so
                # calls like f("a_b") / f("a", "b") or f(1) / f("1") do not
                # share a key.
                key_parts = [func.__module__, func.__qualname__]
                key_parts.extend(repr(arg) for arg in args)
                key_parts.extend(f"{k}={v!r}" for k, v in sorted(kwargs.items()))
                cache_key = hashlib.md5("\x1f".join(key_parts).encode()).hexdigest()
            # Try the cache first.
            cached_result = cm.get(cache_key)
            if cached_result is not None:
                return cached_result
            # Miss: run the function and cache its result.
            result = func(*args, **kwargs)
            cm.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
# Module-wide default cache manager, created lazily on first request
_default_cache_manager = None


def get_default_cache_manager() -> CacheManager:
    """Return the default cache manager, building an in-memory one on first use."""
    global _default_cache_manager
    if _default_cache_manager is not None:
        return _default_cache_manager
    _default_cache_manager = CacheManager(backend='memory', max_size=1000, default_ttl=3600)
    return _default_cache_manager
def set_default_cache_manager(cache_manager: CacheManager):
    """Replace the module-wide default cache manager with ``cache_manager``."""
    global _default_cache_manager
    _default_cache_manager = cache_manager
class CacheStats:
    """Derives efficiency, memory-pressure and tuning hints from a cache manager's stats."""

    def __init__(self, cache_manager: "CacheManager"):
        """Remember the cache manager whose statistics will be analysed."""
        self.cache_manager = cache_manager

    def get_detailed_stats(self) -> Dict[str, Any]:
        """Return the manager's stats extended with derived sections.

        Adds the keys ``efficiency``, ``memory_pressure`` and
        ``recommendations`` to the dict returned by the manager.
        """
        basic_stats = self.cache_manager.get_stats()
        return {
            **basic_stats,
            'efficiency': self._calculate_efficiency(basic_stats),
            'memory_pressure': self._calculate_memory_pressure(basic_stats),
            'recommendations': self._generate_recommendations(basic_stats)
        }

    @staticmethod
    def _utilization(stats: Dict[str, Any]) -> float:
        """Return size / max_size, or 0 when max_size is not positive.

        A missing ``max_size`` falls back to 1000.
        """
        max_size = stats.get('max_size', 1000)
        return stats.get('size', 0) / max_size if max_size > 0 else 0

    def _calculate_efficiency(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """Grade the hit rate as excellent / good / fair / poor."""
        hit_rate = stats.get('hit_rate', 0)
        # Efficiency grade thresholds.
        if hit_rate >= 0.8:
            efficiency_score = 'excellent'
        elif hit_rate >= 0.6:
            efficiency_score = 'good'
        elif hit_rate >= 0.4:
            efficiency_score = 'fair'
        else:
            efficiency_score = 'poor'
        return {
            'hit_rate': hit_rate,
            'efficiency_score': efficiency_score,
            'total_requests': stats.get('hits', 0) + stats.get('misses', 0)
        }

    def _calculate_memory_pressure(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """Classify how full the cache is as low / medium / high."""
        utilization = self._utilization(stats)
        if utilization >= 0.9:
            pressure_level = 'high'
        elif utilization >= 0.7:
            pressure_level = 'medium'
        else:
            pressure_level = 'low'
        return {
            'utilization': utilization,
            'pressure_level': pressure_level,
            'current_size': stats.get('size', 0),
            'max_size': stats.get('max_size', 1000),
            'evictions': stats.get('evictions', 0)
        }

    def _generate_recommendations(self, stats: Dict[str, Any]) -> List[str]:
        """Return tuning suggestions based on hit rate, evictions and utilization."""
        recommendations = []
        hit_rate = stats.get('hit_rate', 0)
        evictions = stats.get('evictions', 0)
        if hit_rate < 0.5:
            recommendations.append("Low hit rate detected. Consider increasing TTL or cache size.")
        if evictions > 0:
            recommendations.append(f"Cache evictions occurred ({evictions}). Consider increasing max_size.")
        # Shared helper guards against a zero max_size, which previously
        # raised ZeroDivisionError here.
        if self._utilization(stats) > 0.9:
            recommendations.append("High cache utilization. Consider increasing cache size.")
        if not recommendations:
            recommendations.append("Cache performance is optimal.")
        return recommendations

    def print_stats_report(self):
        """Print a human-readable report of the detailed statistics to stdout."""
        stats = self.get_detailed_stats()
        print("=" * 50)
        print("CACHE PERFORMANCE REPORT")
        print("=" * 50)
        # Basic statistics
        print(f"Backend Type: {stats['backend_type']}")
        print(f"Cache Size: {stats['size']}/{stats.get('max_size', 'unlimited')}")
        print(f"Hit Rate: {stats['hit_rate']:.2%}")
        print(f"Total Hits: {stats['hits']}")
        print(f"Total Misses: {stats['misses']}")
        print(f"Evictions: {stats['evictions']}")
        # Efficiency analysis
        efficiency = stats['efficiency']
        print(f"\nEfficiency Score: {efficiency['efficiency_score']}")
        print(f"Total Requests: {efficiency['total_requests']}")
        # Memory pressure
        memory = stats['memory_pressure']
        print(f"\nMemory Utilization: {memory['utilization']:.1%}")
        print(f"Pressure Level: {memory['pressure_level']}")
        # Recommendations
        print("\nRecommendations:")
        for rec in stats['recommendations']:
            print(f" - {rec}")
        print("=" * 50)
# Usage example and smoke test
if __name__ == "__main__":
    print("Testing Cache Manager...")

    # 1. Memory cache
    print("\n1. Testing Memory Cache:")
    mem_cache = CacheManager(backend='memory', max_size=100, default_ttl=10)
    # Store a couple of values
    mem_cache.set('test_key', {'data': 'test_value', 'number': 42})
    mem_cache.set('another_key', [1, 2, 3, 4, 5])
    # Read one back
    print(f"Cached data: {mem_cache.get('test_key')}")
    # DataFrame round trip
    sample_columns = {
        'A': [1, 2, 3],
        'B': ['x', 'y', 'z'],
        'C': [1.1, 2.2, 3.3]
    }
    mem_cache.cache_dataframe('test_df', pd.DataFrame(sample_columns))
    retrieved_df = mem_cache.get_dataframe('test_df')
    print(f"DataFrame cached and retrieved successfully: {retrieved_df is not None}")

    # 2. File cache
    print("\n2. Testing File Cache:")
    disk_cache = CacheManager(backend='file', cache_dir='./test_cache')
    disk_cache.set('persistent_key', {'message': 'This will survive restarts'})
    print(f"Persistent data: {disk_cache.get('persistent_key')}")

    # 3. Cache decorator
    print("\n3. Testing Cache Decorator:")

    @cached(ttl=5)
    def expensive_computation(n):
        print(f"Computing factorial of {n}...")
        product = 1
        for factor in range(1, n + 1):
            product *= factor
        return product

    # First call computes the value
    print(f"Result 1: {expensive_computation(5)}")
    # Second call should be served from the cache
    print(f"Result 2: {expensive_computation(5)}")

    # 4. Cache statistics
    print("\n4. Cache Statistics:")
    CacheStats(mem_cache).print_stats_report()

    # 5. Cleanup
    print("\n5. Cleanup:")
    mem_cache.clear()
    disk_cache.clear()
    # Remove the directory created for the file-cache test
    import shutil
    scratch_dir = Path('./test_cache')
    if scratch_dir.exists():
        shutil.rmtree(scratch_dir)
    print("Cache manager tests completed!")