File size: 7,645 Bytes
e4fd6e0
 
 
 
 
 
 
5105d0e
e4fd6e0
 
5105d0e
 
 
 
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dbcd4
 
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
1499a5e
e4fd6e0
19dbcd4
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5105d0e
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
5105d0e
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5105d0e
e4fd6e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Security utilities: headers, CSRF protection, input validation
"""
import os
import logging
from flask import request
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Time zone used by now_ist() to produce India Standard Time timestamps.
IST = ZoneInfo("Asia/Kolkata")

def now_ist() -> datetime:
    """Return the current time in the IST zone as a naive datetime.

    The time is taken in the ``IST`` zone and the tzinfo is then stripped,
    so the result carries IST wall-clock values but no zone information.
    """
    aware_now = datetime.now(tz=IST)
    return aware_now.replace(tzinfo=None)


def init_security(app):
    """Initialize security features for Flask app.

    Registers a no-op ``before_request`` hook and an ``after_request`` hook
    that writes security headers onto every response, then applies the
    session-cookie settings to ``app.config``.

    Args:
        app: The Flask application to configure
    """

    def _is_production() -> bool:
        # Read on every call so the check always reflects the current environment.
        return os.environ.get('FLASK_ENV') == 'production'

    # Content Security Policy - restrictive but functional.
    # 'unsafe-inline' is kept for scripts/styles for compatibility.
    # frame-ancestors replaces X-Frame-Options, which is deliberately not set
    # because it prevents Hugging Face from embedding the app.
    csp_directives = (
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline'",
        "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
        "font-src 'self' https://fonts.gstatic.com",
        "img-src 'self' data: https://res.cloudinary.com",
        "connect-src 'self'",
        "frame-ancestors 'self' https://huggingface.co https://*.hf.space",
        "base-uri 'self'",
        "form-action 'self'",
    )

    # Headers that are the same for every response, in the order they are set.
    static_headers = {
        # Prevent MIME type sniffing
        'X-Content-Type-Options': 'nosniff',
        # Enable XSS protection in older browsers
        'X-XSS-Protection': '1; mode=block',
        'Content-Security-Policy': "; ".join(csp_directives),
        'Referrer-Policy': 'strict-origin-when-cross-origin',
        # Feature policy / Permissions policy
        'Permissions-Policy': 'geolocation=(), microphone=(), camera=(), usb=(), payment=()',
    }

    @app.before_request
    def set_security_headers():
        """Placeholder hook; the headers themselves are set in after_request."""
        return None

    @app.after_request
    def add_security_headers(response):
        """Add security headers to all responses"""
        for header_name, header_value in static_headers.items():
            response.headers[header_name] = header_value

        # HSTS (HTTP Strict-Transport-Security): sent when the request is
        # secure or when FLASK_ENV is 'production'.
        if request.is_secure or _is_production():
            response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'

        return response

    # Session security
    session_settings = {
        'SESSION_COOKIE_SECURE': _is_production(),
        'SESSION_COOKIE_HTTPONLY': True,
        'SESSION_COOKIE_SAMESITE': 'Lax',
        'PERMANENT_SESSION_LIFETIME': timedelta(days=30),
    }
    app.config.update(**session_settings)

    logger.info("Security headers and features initialized")


def sanitize_filename(filename: str, max_length: int = 255) -> str:
    """
    Sanitize filename to prevent directory traversal and other attacks.

    Path components are dropped, null bytes are removed, every character
    other than word characters, dash and dot is replaced with an underscore,
    and leading/trailing dots and spaces are stripped. An empty result
    becomes 'file'.

    Args:
        filename: The filename to sanitize
        max_length: Maximum length for the sanitized filename
            (expected to be at least 1)

    Returns:
        Safe filename, at most max_length characters long
    """
    import re

    # Remove any path components
    filename = os.path.basename(filename)

    # Remove null bytes
    filename = filename.replace('\0', '')

    # Allow only safe characters (alphanumeric, dash, underscore, dot)
    filename = re.sub(r'[^\w\-\.]', '_', filename)

    # Remove leading/trailing dots and spaces
    filename = filename.strip('. ')

    # Prevent empty filename
    if not filename:
        filename = 'file'

    # Limit length
    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        stem_budget = max_length - len(ext)
        if stem_budget > 0:
            # Preserve extension and shorten only the stem
            filename = name[:stem_budget] + ext
        else:
            # The extension alone would use up the whole budget. Slicing the
            # stem with a zero/negative bound would overshoot max_length or
            # leave a bare ".ext" name, so truncate the whole name instead.
            filename = filename[:max_length].rstrip('. ') or 'file'

    return filename


def validate_file_extension(filename: str, allowed_extensions: list) -> bool:
    """
    Check whether a filename ends in one of the allowed extensions.

    The comparison is case-insensitive on both sides. A filename that is
    empty or contains no dot is rejected.

    Args:
        filename: The filename to validate
        allowed_extensions: List of allowed extensions (without dots)

    Returns:
        True if extension is allowed, False otherwise
    """
    if not filename:
        return False

    # The extension is everything after the last dot.
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return False

    permitted = tuple(candidate.lower() for candidate in allowed_extensions)
    return suffix.lower() in permitted


def mask_sensitive_data(data: dict, fields_to_mask: list) -> dict:
    """
    Return a deep copy of a dictionary with selected top-level fields masked.

    Each masked value is converted to a string. Values longer than four
    characters keep their first two and last two characters with asterisks
    in between; shorter values are replaced entirely by asterisks. The
    input dictionary is not modified.

    Args:
        data: Dictionary containing data to mask
        fields_to_mask: List of field names to mask

    Returns:
        Dictionary with masked fields
    """
    import copy

    result = copy.deepcopy(data)
    for name in fields_to_mask:
        if name not in result:
            continue
        text = str(result[name])
        hidden = len(text) - 4
        result[name] = (
            text[:2] + '*' * hidden + text[-2:]
            if hidden > 0
            else '*' * len(text)
        )

    return result


def get_client_info() -> dict:
    """Collect details of the current Flask request for logging.

    Returns:
        Dictionary with the remote address, User-Agent header ('Unknown'
        when absent), endpoint, HTTP method and an ISO-formatted
        now_ist() timestamp
    """
    remote_ip = request.remote_addr
    agent = request.headers.get('User-Agent', 'Unknown')
    endpoint_name = request.endpoint
    http_method = request.method
    stamp = now_ist().isoformat()

    info = {}
    info['ip_address'] = remote_ip
    info['user_agent'] = agent
    info['endpoint'] = endpoint_name
    info['method'] = http_method
    info['timestamp'] = stamp
    return info


class RateLimiter:
    """Simple in-memory rate limiter for protecting against abuse.

    Each key maps to a list of ``(timestamp, count)`` entries. A key is
    limited once the counts recorded inside the trailing window reach
    ``max_requests``. State lives only in this object's ``requests`` dict.
    """

    def __init__(self, max_requests: int = 100, window_seconds: int = 60):
        """Store the limit and the window length, and start with no history.

        Args:
            max_requests: Number of requests at which a key becomes limited
            window_seconds: Length of the trailing window in seconds
        """
        self.requests = {}  # {key: [(timestamp, count), ...]}
        self.window_seconds = window_seconds
        self.max_requests = max_requests

    def is_rate_limited(self, key: str) -> bool:
        """Report whether a key has reached the limit within the window.

        Entries of a known key that fall outside the window are dropped as
        a side effect. An unknown key is not added.

        Args:
            key: Identifier whose request history is checked

        Returns:
            True if the in-window count is at or above max_requests
        """
        cutoff = now_ist() - timedelta(seconds=self.window_seconds)

        # Drop entries at or before the cutoff for keys we already track
        if key in self.requests:
            recent = []
            for stamp, hits in self.requests[key]:
                if stamp > cutoff:
                    recent.append((stamp, hits))
            self.requests[key] = recent

        # Add up what is left in the window
        used = 0
        for _, hits in self.requests.get(key, []):
            used += hits

        return not used < self.max_requests

    def record_request(self, key: str):
        """Record one request for a key at the current now_ist() time.

        Args:
            key: Identifier the request is attributed to
        """
        stamp = now_ist()
        history = self.requests.setdefault(key, [])

        # Timestamps are compared for exact equality, so a new entry is
        # merged into the previous one only when both carry the identical
        # timestamp; otherwise a fresh entry is appended.
        if history and history[-1][0] == stamp:
            previous_stamp, previous_hits = history[-1]
            history[-1] = (previous_stamp, previous_hits + 1)
            return

        history.append((stamp, 1))


# Shared limiter instances used by the check_* helpers below.

# Login: 5 attempts in 5 minutes
login_rate_limiter = RateLimiter(
    max_requests=5,
    window_seconds=5 * 60,
)

# Uploads: 20 uploads per hour
upload_rate_limiter = RateLimiter(
    max_requests=20,
    window_seconds=60 * 60,
)


def check_login_rate_limit(identifier: str) -> tuple[bool, str]:
    """
    Decide whether a login attempt should be rejected by the rate limiter.

    An attempt that is allowed is recorded against the identifier; an
    attempt that is rejected is not recorded.

    Args:
        identifier: Key under which login attempts are counted

    Returns:
        (is_limited, message); the message is empty when not limited
    """
    blocked = login_rate_limiter.is_rate_limited(identifier)
    if not blocked:
        login_rate_limiter.record_request(identifier)
        return False, ""

    return True, "Too many login attempts. Please try again later."


def check_upload_rate_limit(user_id: int) -> tuple[bool, str]:
    """
    Decide whether an upload should be rejected by the rate limiter.

    Uploads are counted under the key ``upload_<user_id>``. An upload that
    is allowed is recorded; an upload that is rejected is not recorded.

    Args:
        user_id: Id of the user performing the upload

    Returns:
        (is_limited, message); the message is empty when not limited
    """
    limiter_key = f"upload_{user_id}"
    blocked = upload_rate_limiter.is_rate_limited(limiter_key)
    if not blocked:
        upload_rate_limiter.record_request(limiter_key)
        return False, ""

    return True, "Upload rate limit exceeded. Maximum 20 uploads per hour."