File size: 2,106 Bytes
1af4cba
 
 
 
 
 
 
 
 
c4b0562
 
 
1af4cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ef2798
 
 
1af4cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
SalesPath β€” Pre-flight Dependency Check
Run at the start of training to catch version mismatches early.
"""
import sys
import importlib

REQUIRED_PACKAGES = {
    "torch": "2.0.0",
    "transformers": "4.44.0",
    "trl": "0.14.0",
    "peft": "0.11.0",
    "datasets": "2.0.0",
    "fastapi": "0.100.0",
    "httpx": "0.24.0",
    "openenv": None,
    "accelerate": "0.25.0",
}

# Pre-flight checks. Any failed check flips all_ok to False; the script keeps
# going so every problem is reported in one run, then exits non-zero.
all_ok = True

print("=" * 60)
print("SalesPath Pre-flight Check")
print("=" * 60)

# Python version: older than 3.10 is flagged (not fatal on its own).
print(f"Python: {sys.version}")
if sys.version_info < (3, 10):
    print("  WARNING: Python >= 3.10 recommended")
    all_ok = False

# CUDA availability (informational); a failing torch import counts as a failure.
try:
    import torch
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        # Prefer `total_memory`, fall back to `total_mem` (presumably an
        # alternate attribute name on some builds), else report 0.
        vram_gb = getattr(props, 'total_memory', getattr(props, 'total_mem', 0)) / 1e9
        print(f"VRAM: {vram_gb:.1f} GB")
except Exception as e:
    print(f"PyTorch: ERROR — {e}")
    all_ok = False

# `packaging` is needed for version comparison. Import it once, outside the
# per-package loop, so that a missing `packaging` is reported as itself rather
# than being caught by the loop's `except ImportError` and mislabelling an
# installed package as "NOT FOUND".
try:
    from packaging import version
except ImportError:
    version = None
    print("packaging: NOT FOUND ❌ (required to compare versions)")
    all_ok = False

from importlib import metadata as _metadata


def _installed_version(module, dist_name):
    """Return the version string of an imported *module*, or ``"unknown"``.

    Uses ``module.__version__`` when present; otherwise falls back to the
    installed distribution metadata for *dist_name*. Assumes the distribution
    name matches the import name, which may not hold for every package.
    """
    found = getattr(module, "__version__", None)
    if found:
        return str(found)
    try:
        return _metadata.version(dist_name)
    except _metadata.PackageNotFoundError:
        return "unknown"


# Check each package: it must import, and (when a minimum is given) its
# version must be >= that minimum.
for pkg_name, min_version in REQUIRED_PACKAGES.items():
    try:
        mod = importlib.import_module(pkg_name)
    except ImportError:
        print(f"{pkg_name}: NOT FOUND ❌")
        all_ok = False
        continue
    except Exception as e:
        # Importing raised something other than ImportError.
        print(f"{pkg_name}: ERROR — {e} ❌")
        all_ok = False
        continue

    ver = _installed_version(mod, pkg_name)
    if not min_version:
        verdict = " ✅"
    elif version is None or ver == "unknown":
        # Either `packaging` is missing or no version could be determined.
        verdict = f" (cannot verify >= {min_version}) ⚠️"
        all_ok = False
    else:
        try:
            too_old = version.parse(ver) < version.parse(min_version)
        except version.InvalidVersion:
            verdict = f" (cannot verify >= {min_version}) ⚠️"
            all_ok = False
        else:
            if too_old:
                verdict = f" (needs >= {min_version}) ⚠️"
                all_ok = False
            else:
                verdict = " ✅"
    print(f"{pkg_name}: {ver}{verdict}")

print("=" * 60)
if all_ok:
    print("All checks passed ✅")
else:
    print("Some checks failed ⚠️ — training may still work")
print("=" * 60)

# Exit status for callers/CI: 0 when every check passed, 1 otherwise.
sys.exit(0 if all_ok else 1)