# gate/user_manager.py
# Uploaded by harii66 ("Upload 23 files", commit b4edbc0, verified).
import hashlib
import json
import os
import secrets
import string
import time
import traceback
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from dotenv import load_dotenv
from pydantic import BaseModel
from upstash_redis import Redis
# Populate environment variables from a .env file (python-dotenv) so that
# REDIS_URL / REDIS_TOKEN are available when UserManager is constructed below.
load_dotenv()
# Catalogue of selectable user badges, keyed by badge id. Each entry carries
# its own id, display name, icon, and the CSS colour values used to render it.
# Specs are listed as (id, name, icon, gradient, color, border, glow).
AVAILABLE_BADGES = {
    badge_id: {
        'id': badge_id,
        'name': label,
        'icon': icon,
        'gradient': gradient,
        'color': text_color,
        'border': border_color,
        'glow': glow_color,
    }
    for badge_id, label, icon, gradient, text_color, border_color, glow_color in (
        ('vip_dog', '尊贵狗牌', '🏷️',
         'linear-gradient(135deg, #FFD700 0%, #FFA500 100%)',
         '#8B4513', '#FFD700', 'rgba(255, 215, 0, 0.5)'),
        ('diamond', '钻石会员', '💎',
         'linear-gradient(135deg, #B9F2FF 0%, #00D4FF 100%)',
         '#003D5C', '#00D4FF', 'rgba(0, 212, 255, 0.5)'),
        ('crown', '皇冠用户', '👑',
         'linear-gradient(135deg, #FFE66D 0%, #FFB800 100%)',
         '#8B4000', '#FFB800', 'rgba(255, 184, 0, 0.5)'),
        ('star', '星标用户', '⭐',
         'linear-gradient(135deg, #FFF7A5 0%, #FFDF00 100%)',
         '#8B7500', '#FFDF00', 'rgba(255, 223, 0, 0.5)'),
        ('fire', '火焰用户', '🔥',
         'linear-gradient(135deg, #FF6B6B 0%, #EE5A24 100%)',
         '#8B0000', '#EE5A24', 'rgba(238, 90, 36, 0.5)'),
        ('rocket', '火箭用户', '🚀',
         'linear-gradient(135deg, #4ECDC4 0%, #44A08D 100%)',
         '#0D4C4A', '#44A08D', 'rgba(68, 160, 141, 0.5)'),
        ('rainbow', '彩虹用户', '🌈',
         'linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%)',
         '#4A148C', '#764ba2', 'rgba(118, 75, 162, 0.5)'),
    )
}
class User(BaseModel):
    """A user account together with that user's client-side settings.

    Account data and settings live in one model so they are stored and loaded
    as a single record. ``Config.json_encoders`` tells pydantic to render
    datetime values with ``isoformat()`` in its own JSON output.
    """

    # --- account fields ---
    username: str
    password_hash: str
    created_at: datetime
    expires_at: Optional[datetime] = None  # None means the account never expires
    last_login: Optional[datetime] = None
    is_active: bool = True
    is_admin: bool = False
    created_by: str = "admin"
    notes: str = ""
    badge: Optional[str] = None  # a key of AVAILABLE_BADGES, or None

    # --- per-user settings merged into the account record ---
    favorite_channels: List[str] = list()
    download_concurrency: int = 16
    batch_download_concurrency: int = 3
    fab_position: Dict[str, float] = dict(bottom=30, right=30)
    playback_history: List[Dict] = list()
    program_reminders: List[Dict] = list()

    class Config:
        json_encoders = {
            datetime: lambda value: None if not value else value.isoformat(),
        }
class UserManager:
    """Manage user accounts and per-user settings.

    Users are cached in ``self.users`` (username -> User). When REDIS_URL and
    REDIS_TOKEN are both set and the server answers ``ping``, every user is
    also persisted as a JSON string under ``user:<username>`` and its name is
    added to the ``users:all`` set. Otherwise ``self.redis`` is None and data
    lives only in process memory.
    """

    def __init__(self):
        self.redis = None
        self.users: Dict[str, "User"] = {}
        redis_url = os.getenv('REDIS_URL', '')
        redis_token = os.getenv('REDIS_TOKEN', '')
        if not redis_url or not redis_token:
            # No credentials configured: in-memory storage only.
            return
        try:
            client = Redis(url=redis_url, token=redis_token)
            client.ping()
        except Exception as e:
            # Log why we fall back to memory instead of hiding the failure.
            print(f"❌ 连接 Redis 失败,使用内存存储: {e}")
            return
        self.redis = client
        self.load_all_users()

    # ==================== storage helpers ====================

    def _get_user_key(self, username: str) -> str:
        """Return the Redis key holding *username*'s JSON record."""
        return f"user:{username}"

    @staticmethod
    def _default_settings() -> Dict:
        """Return a fresh dict of default per-user settings.

        New list/dict objects are built on every call, so the result can be
        mutated or assigned to a user without aliasing another user's data.
        """
        return {
            'favorite_channels': [],
            'download_concurrency': 16,
            'batch_download_concurrency': 3,
            'fab_position': {'bottom': 30, 'right': 30},
            'playback_history': [],
            'program_reminders': [],
        }

    @staticmethod
    def _serialize_user(user: "User") -> str:
        """Serialize *user* (including settings) to the JSON stored in Redis.

        Set datetime fields are written as ISO-8601 strings; unset optional
        datetimes are written as ``null``.
        """
        user_dict = user.dict()
        for field in ('created_at', 'expires_at', 'last_login'):
            if user_dict.get(field):
                user_dict[field] = user_dict[field].isoformat()
        return json.dumps(user_dict)

    def _write_user(self, user: "User") -> None:
        """Store *user*'s record and register the name in ``users:all``.

        Raises whatever the Redis client raises; callers handle errors.
        """
        self.redis.set(self._get_user_key(user.username), self._serialize_user(user))
        self.redis.sadd("users:all", user.username)

    def _save_user_to_redis(self, user: "User"):
        """Persist *user* to Redis in a single attempt; errors are logged, not raised."""
        if not self.redis:
            return
        try:
            self._write_user(user)
            print(f"✅ 用户 {user.username} 已保存到 Redis(包含所有设置)")
        except Exception as e:
            print(f"❌ 保存用户失败: {e}")

    def _save_user_to_redis_with_retry(self, user: "User", max_retries: int = 3):
        """Persist *user* to Redis, retrying failed writes with exponential backoff.

        Args:
            user: the user to persist.
            max_retries: total number of attempts; values below 1 make no attempt.

        Returns:
            True if a write succeeded. False if Redis is unavailable or every
            attempt failed. Errors are logged, never raised.
        """
        if not self.redis:
            print(f"⚠️ Redis不可用,跳过保存用户 {user.username}")
            return False
        for attempt in range(max_retries):
            try:
                self._write_user(user)
                print(f"✅ 用户 {user.username} 已保存到 Redis(重试第 {attempt + 1} 次成功)")
                return True
            except Exception as e:
                print(f"❌ 保存用户失败(第 {attempt + 1} 次重试): {e}")
                if attempt == max_retries - 1:
                    print(f"❌ 用户 {user.username} 保存到 Redis 失败,已达最大重试次数")
                    return False
                # Exponential backoff between attempts: 0.5s, 1s, 2s, ...
                time.sleep(0.5 * (2 ** attempt))
        return False

    def _load_user_from_redis(self, username: str) -> Optional["User"]:
        """Load one user (including settings) from Redis.

        Returns None when Redis is unavailable or the key is missing. Also
        returns None when the stored JSON cannot be turned into a User; that
        error is logged with a traceback.
        """
        if not self.redis:
            return None
        try:
            user_json = self.redis.get(self._get_user_key(username))
            if not user_json:
                return None
            user_dict = json.loads(user_json)
            for field in ('created_at', 'expires_at', 'last_login'):
                if user_dict.get(field):
                    user_dict[field] = datetime.fromisoformat(user_dict[field])
            # Records written before the settings fields existed lack them;
            # fill in defaults so old data still loads.
            for key, value in self._default_settings().items():
                user_dict.setdefault(key, value)
            return User(**user_dict)
        except Exception as e:
            print(f"❌ 加载用户失败: {e}")
            traceback.print_exc()
            return None

    def load_all_users(self):
        """(Re)load every user listed in ``users:all`` into the cache.

        Users that fail to load are skipped. Users already cached but absent
        from Redis are left in the cache.
        """
        if not self.redis:
            return
        try:
            usernames = self.redis.smembers("users:all")
            if not usernames:
                return
            for username in usernames:
                user = self._load_user_from_redis(username)
                if user:
                    self.users[username] = user
            print(f"✅ 已加载 {len(self.users)} 个用户")
        except Exception as e:
            print(f"❌ 加载用户列表失败: {e}")

    def _delete_user_from_redis(self, username: str):
        """Remove *username*'s record and set membership from Redis (best effort, logged)."""
        if not self.redis:
            return
        try:
            self.redis.delete(self._get_user_key(username))
            self.redis.srem("users:all", username)
        except Exception as e:
            print(f"❌ 删除用户失败: {e}")

    # ==================== basic user management ====================

    def generate_password(self, length: int = 12) -> str:
        """Return a random alphanumeric (ASCII letters and digits) password of *length* chars."""
        alphabet = string.ascii_letters + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(length))

    def hash_password(self, password: str) -> str:
        """Return the hex SHA-256 digest of *password* (UTF-8 encoded).

        NOTE(review): unsalted SHA-256 is weak for password storage. It is kept
        because stored hashes and ``verify_user`` (which receives a hash from its
        caller) depend on this exact scheme. Migrating requires a coordinated
        change.
        """
        return hashlib.sha256(password.encode()).hexdigest()

    def create_user(
        self,
        username: str,
        password: Optional[str] = None,
        expires_days: Optional[int] = None,
        notes: str = "",
        badge: Optional[str] = None,
        is_admin: bool = False
    ) -> tuple["User", str]:
        """Create, cache and persist a new user.

        Args:
            username: new account name; must not exist in memory or in Redis.
            password: plain-text password; a random one is generated when
                omitted or empty.
            expires_days: days until expiry; None or 0 means never expires.
            notes: free-form admin notes.
            badge: optional key of AVAILABLE_BADGES.
            is_admin: whether the account is an administrator.

        Returns:
            ``(user, plain_password)``. Only the hash is stored on the user.

        Raises:
            ValueError: the user already exists or the badge is unknown.
        """
        # get_user also consults Redis, so an account created by another
        # process (not yet cached here) is not silently overwritten.
        if self.get_user(username) is not None:
            raise ValueError(f"User {username} already exists")
        if badge and badge not in AVAILABLE_BADGES:
            raise ValueError(f"Invalid badge: {badge}")
        plain_password = password or self.generate_password()
        expires_at = datetime.now() + timedelta(days=expires_days) if expires_days else None
        user = User(
            username=username,
            password_hash=self.hash_password(plain_password),
            created_at=datetime.now(),
            expires_at=expires_at,
            notes=notes,
            badge=badge,
            is_admin=is_admin
        )
        self.users[username] = user
        self._save_user_to_redis_with_retry(user)
        return user, plain_password

    def verify_user(self, username: str, password_hash: str) -> bool:
        """Check a login attempt against the stored password hash.

        Args:
            username: account name; loaded from Redis if not cached.
            password_hash: hash supplied by the caller. It is compared to
                ``User.password_hash``, so presumably it is produced like
                ``hash_password`` (confirm with the login caller).

        Returns:
            True only for an existing, active, unexpired user whose hash matches.
            If the user is past ``expires_at``, it is marked inactive and
            persisted. A successful check updates and persists ``last_login``.
        """
        user = self.get_user(username)
        if user is None or not user.is_active:
            return False
        if user.expires_at and datetime.now() > user.expires_at:
            user.is_active = False
            self._save_user_to_redis_with_retry(user)
            return False
        # Constant-time comparison, so response timing does not reveal how
        # much of the hash matched. Non-str input can never match.
        if not isinstance(password_hash, str) or not secrets.compare_digest(
            user.password_hash.encode(), password_hash.encode()
        ):
            return False
        user.last_login = datetime.now()
        self._save_user_to_redis_with_retry(user)
        return True

    def delete_user(self, username: str) -> bool:
        """Delete *username* from the cache and Redis; return False if it does not exist."""
        # get_user also finds users that exist only in Redis.
        if self.get_user(username) is None:
            return False
        self.users.pop(username, None)
        self._delete_user_from_redis(username)
        return True

    def _set_active(self, username: str, active: bool) -> bool:
        """Set ``is_active`` on *username* and persist it; False if the user does not exist."""
        user = self.get_user(username)
        if user is None:
            return False
        user.is_active = active
        self._save_user_to_redis_with_retry(user)
        return True

    def deactivate_user(self, username: str) -> bool:
        """Mark *username* inactive and persist; False if the user does not exist."""
        return self._set_active(username, False)

    def activate_user(self, username: str) -> bool:
        """Mark *username* active and persist; False if the user does not exist."""
        return self._set_active(username, True)

    def extend_expiry(self, username: str, days: int) -> bool:
        """Push *username*'s expiry back by *days* and persist it.

        A user without an expiry gets ``now + days``. NOTE(review): for an
        already-expired user the days are added to the past expiry date, not
        to now. This method also does not reactivate the account.

        Returns:
            False if the user does not exist, else True.
        """
        user = self.get_user(username)
        if user is None:
            return False
        if user.expires_at:
            user.expires_at += timedelta(days=days)
        else:
            user.expires_at = datetime.now() + timedelta(days=days)
        self._save_user_to_redis_with_retry(user)
        return True

    def set_badge(self, username: str, badge: Optional[str]) -> bool:
        """Set (or clear with None/empty) *username*'s badge and persist it.

        Returns:
            False if the user does not exist, else True.

        Raises:
            ValueError: *badge* is not a key of AVAILABLE_BADGES.
        """
        user = self.get_user(username)
        if user is None:
            return False
        if badge and badge not in AVAILABLE_BADGES:
            raise ValueError(f"Invalid badge: {badge}")
        user.badge = badge
        self._save_user_to_redis_with_retry(user)
        return True

    # ==================== user settings ====================

    def get_user_data(self, username: str) -> Optional[Dict]:
        """Return *username*'s settings dict, or None if the user does not exist.

        The list/dict values are the user's live objects. Mutating them
        changes the cached user but does not persist anything.
        """
        user = self.get_user(username)
        if user is None:
            return None
        return {key: getattr(user, key) for key in self._default_settings()}

    def get_user_settings(self, username: str) -> Dict:
        """Return *username*'s settings, or the default settings if the user does not exist."""
        data = self.get_user_data(username)
        return data if data is not None else self._default_settings()

    def delete_user_settings(self, username: str) -> bool:
        """Reset *username*'s settings to defaults and persist; False if the user does not exist."""
        user = self.get_user(username)
        if user is None:
            return False
        for key, value in self._default_settings().items():
            setattr(user, key, value)
        self._save_user_to_redis_with_retry(user)
        print(f"✅ 用户 {username} 设置已重置为默认值")
        return True

    def update_user_data(self, username: str, data: Dict) -> bool:
        """Apply a partial settings update and persist it.

        Only recognised setting keys present in *data* are assigned; other
        keys are ignored. Values are stored as given, without validation.

        Returns:
            False if the user does not exist, else True.
        """
        user = self.get_user(username)
        if user is None:
            return False
        for key in self._default_settings():
            if key in data:
                setattr(user, key, data[key])
        self._save_user_to_redis_with_retry(user)
        print(f"✅ 用户 {username} 数据已实时保存到Redis: {list(data.keys())}")
        return True

    # ==================== convenience accessors ====================

    def get_favorites(self, username: str) -> List[str]:
        """Return favourite channels ([] for an unknown user)."""
        return self.get_user_settings(username)['favorite_channels']

    def set_favorites(self, username: str, favorites: List[str]) -> bool:
        """Replace favourite channels; False if the user does not exist."""
        return self.update_user_data(username, {'favorite_channels': favorites})

    def get_download_concurrency(self, username: str) -> int:
        """Return download concurrency (default 16 for an unknown user)."""
        return self.get_user_settings(username)['download_concurrency']

    def set_download_concurrency(self, username: str, concurrency: int) -> bool:
        """Set download concurrency; False if the user does not exist."""
        return self.update_user_data(username, {'download_concurrency': concurrency})

    def get_batch_concurrency(self, username: str) -> int:
        """Return batch download concurrency (default 3 for an unknown user)."""
        return self.get_user_settings(username)['batch_download_concurrency']

    def set_batch_concurrency(self, username: str, concurrency: int) -> bool:
        """Set batch download concurrency; False if the user does not exist."""
        return self.update_user_data(username, {'batch_download_concurrency': concurrency})

    def get_fab_position(self, username: str) -> Dict[str, float]:
        """Return the FAB position ({'bottom': 30, 'right': 30} for an unknown user)."""
        return self.get_user_settings(username)['fab_position']

    def set_fab_position(self, username: str, position: Dict[str, float]) -> bool:
        """Set the FAB position; False if the user does not exist."""
        return self.update_user_data(username, {'fab_position': position})

    def get_user(self, username: str) -> Optional["User"]:
        """Return the cached user, loading and caching it from Redis on a miss; None if absent."""
        if username in self.users:
            return self.users[username]
        user = self._load_user_from_redis(username)
        if user:
            self.users[username] = user
        return user

    def list_users(self) -> List["User"]:
        """Return all users, reloading from Redis first when it is available.

        On an unexpected error the traceback is printed and [] is returned.
        """
        try:
            if self.redis:
                self.load_all_users()
            return list(self.users.values())
        except Exception:
            traceback.print_exc()
            return []

    def get_stats(self) -> dict:
        """Return user counts and the storage backend in use.

        ``inactive`` is ``total - active``. ``expired`` counts users whose
        ``expires_at`` has passed, whether or not they were deactivated yet.
        On an unexpected error the error is logged and zeroed stats with
        storage "Error" are returned.
        """
        try:
            if self.redis:
                self.load_all_users()
            users = list(self.users.values())
            now = datetime.now()
            total = len(users)
            active = sum(1 for u in users if u.is_active)
            expired = sum(1 for u in users if u.expires_at and now > u.expires_at)
            return {
                "total": total,
                "active": active,
                "expired": expired,
                "inactive": total - active,
                "storage": "Redis (Upstash)" if self.redis else "Memory (临时)"
            }
        except Exception as e:
            print(f"❌ 获取统计信息失败: {e}")
            return {
                "total": 0,
                "active": 0,
                "expired": 0,
                "inactive": 0,
                "storage": "Error"
            }

    def get_available_badges(self) -> dict:
        """Return the module-level AVAILABLE_BADGES catalogue (not a copy)."""
        return AVAILABLE_BADGES
# Module-level shared instance. Creating it at import time reads REDIS_URL /
# REDIS_TOKEN and, when both are set, connects to Redis and loads all users.
user_manager = UserManager()