paijo77 committed on
Commit
2e7b8f4
·
verified ·
1 Parent(s): 85fd511

update app/source_validator.py

Browse files
Files changed (1) hide show
  1. app/source_validator.py +145 -0
app/source_validator.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import ipaddress
import socket
from typing import List, Optional
from urllib.parse import urlparse

import aiohttp
from pydantic import BaseModel, Field

from app.grabber import GitHubGrabber
from app.models import SourceConfig, SourceType
11
+
12
+
13
class SourceValidationResult(BaseModel):
    """Result of validating a single proxy source.

    Attributes stay at their defaults when validation fails before the
    corresponding step runs.
    """

    valid: bool  # True only when every validation step passed
    error_message: Optional[str] = None  # human-readable reason when valid is False
    proxy_count: int = 0  # number of proxies extracted from the source
    # Field(default_factory=list) is the pydantic-recommended form for a
    # mutable default (avoids a shared literal in the class body).
    sample_proxies: List[str] = Field(default_factory=list)
18
+
19
+
20
+ class SourceValidator:
21
+ def __init__(self, timeout: int = 15):
22
+ self.timeout = aiohttp.ClientTimeout(total=timeout)
23
+ self.grabber = GitHubGrabber()
24
+
25
+ def is_internal_url(self, url: str) -> bool:
26
+ """Check if the URL points to an internal network (SSRF protection)."""
27
+ try:
28
+ parsed = urlparse(url)
29
+ hostname = parsed.hostname
30
+ if not hostname:
31
+ return True
32
+
33
+ # Check for common internal hostnames
34
+ if hostname.lower() in ["localhost", "127.0.0.1", "::1", "0.0.0.0"]:
35
+ return True
36
+
37
+ # Resolve hostname to IP and check if it's private
38
+ addr_info = socket.getaddrinfo(hostname, None)
39
+ for item in addr_info:
40
+ ip = item[4][0]
41
+ if ipaddress.ip_address(ip).is_private:
42
+ return True
43
+ if ipaddress.ip_address(ip).is_loopback:
44
+ return True
45
+ if ipaddress.ip_address(ip).is_link_local:
46
+ return True
47
+
48
+ return False
49
+ except Exception:
50
+ # If resolution fails, we'll treat it as potentially unsafe or handle it in reachable check
51
+ return False
52
+
53
+ async def validate_url_reachable(self, url: str) -> tuple[bool, Optional[str]]:
54
+ if self.is_internal_url(url):
55
+ return False, "Access to internal networks is restricted (SSRF protection)"
56
+
57
+ try:
58
+ async with aiohttp.ClientSession(timeout=self.timeout) as session:
59
+ async with session.get(url, ssl=False) as resp:
60
+ if resp.status == 200:
61
+ content_type = resp.headers.get("Content-Type", "")
62
+ content = await resp.text()
63
+
64
+ if len(content) < 10:
65
+ return False, "Source content too short (< 10 characters)"
66
+
67
+ if len(content) > 50_000_000:
68
+ return False, "Source content too large (> 50MB)"
69
+
70
+ return True, None
71
+ elif resp.status == 404:
72
+ return False, "Source not found (404)"
73
+ elif resp.status == 403:
74
+ return False, "Access forbidden (403)"
75
+ elif resp.status >= 500:
76
+ return False, f"Server error ({resp.status})"
77
+ else:
78
+ return False, f"HTTP error {resp.status}"
79
+
80
+ except asyncio.TimeoutError:
81
+ return False, "Connection timeout - source took too long to respond"
82
+ except aiohttp.ClientConnectorError:
83
+ return False, "Cannot connect to source URL"
84
+ except Exception as e:
85
+ return False, f"Error: {str(e)[:100]}"
86
+
87
+ async def validate_source_format(
88
+ self, source: SourceConfig
89
+ ) -> tuple[bool, Optional[str]]:
90
+ url_str = str(source.url)
91
+
92
+ if source.type == SourceType.GITHUB_RAW:
93
+ if "github.com" not in url_str:
94
+ return False, "GitHub source must contain 'github.com'"
95
+ if "/raw/" not in url_str and "githubusercontent.com" not in url_str:
96
+ return False, "GitHub source must be a raw file URL"
97
+
98
+ elif source.type == SourceType.SUBSCRIPTION_BASE64:
99
+ if not url_str.startswith(("http://", "https://")):
100
+ return False, "Subscription source must start with http:// or https://"
101
+
102
+ return True, None
103
+
104
+ async def test_proxy_extraction(
105
+ self, source: SourceConfig
106
+ ) -> tuple[int, List[str], Optional[str]]:
107
+ try:
108
+ proxies = await self.grabber.extract_proxies(source)
109
+
110
+ if not proxies:
111
+ return 0, [], "No proxies found in source"
112
+
113
+ proxy_urls = [p.url for p in proxies[:5]]
114
+ return len(proxies), proxy_urls, None
115
+
116
+ except Exception as e:
117
+ return 0, [], f"Failed to extract proxies: {str(e)[:100]}"
118
+
119
+ async def validate_source(self, source: SourceConfig) -> SourceValidationResult:
120
+ is_format_valid, format_error = await self.validate_source_format(source)
121
+ if not is_format_valid:
122
+ return SourceValidationResult(valid=False, error_message=format_error)
123
+
124
+ is_reachable, reachable_error = await self.validate_url_reachable(
125
+ str(source.url)
126
+ )
127
+ if not is_reachable:
128
+ return SourceValidationResult(valid=False, error_message=reachable_error)
129
+
130
+ (
131
+ proxy_count,
132
+ sample_proxies,
133
+ extraction_error,
134
+ ) = await self.test_proxy_extraction(source)
135
+ if extraction_error:
136
+ return SourceValidationResult(
137
+ valid=False, error_message=extraction_error, proxy_count=proxy_count
138
+ )
139
+
140
+ return SourceValidationResult(
141
+ valid=True, proxy_count=proxy_count, sample_proxies=sample_proxies
142
+ )
143
+
144
+
145
# Module-level shared validator instance (default 15-second total timeout).
source_validator = SourceValidator()