File size: 11,710 Bytes
5f3e9f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""Screenshot capture using Playwright with browser pooling."""
import atexit
from playwright.sync_api import sync_playwright
from PIL import Image
import tempfile
import time
import os
import io
import threading


class BrowserPool:
    """Thread-safe browser pool that keeps a Chromium instance per thread.

    Each thread that asks for a page gets its own Playwright + Chromium pair,
    stored in ``threading.local``. To support clean process shutdown we
    additionally keep a global registry (thread-id -> playwright/browser
    references) so :meth:`shutdown_all` can close every Chromium, including
    those launched by threads that are not currently executing. Without the
    registry, :meth:`shutdown` only ever cleans up the thread that calls it.

    All public methods take ``self._lock``, so registry mutations and
    launch/close operations never interleave across threads.
    """

    def __init__(self):
        # Per-thread slots: ``playwright`` and ``browser``.
        self._local = threading.local()
        # Guards ``_registry`` and serialises launch/close operations.
        self._lock = threading.Lock()
        # tid -> {"playwright": pw, "browser": br}
        self._registry: dict = {}

    @staticmethod
    def _close_quietly(browser, playwright):
        """Close *browser*, then stop *playwright*, ignoring errors from either.

        Shutdown is best-effort: a browser that already died must not prevent
        the Playwright driver from being stopped, so each step is guarded
        separately.
        """
        if browser is not None:
            try:
                browser.close()
            except Exception:
                pass
        if playwright is not None:
            try:
                playwright.stop()
            except Exception:
                pass

    def _register_current_thread(self):
        """Record this thread's playwright/browser pair in the global registry."""
        self._registry[threading.get_ident()] = {
            'playwright': self._local.playwright,
            'browser': self._local.browser,
        }

    def _ensure_browser(self):
        """Launch browser if not already running in this thread.

        Callers are expected to hold ``self._lock`` (``get_page`` does),
        because this method writes to the shared registry.

        Raises:
            Exception: whatever Playwright raises when starting the driver or
                launching Chromium. On a failed first launch the freshly
                started driver is stopped again and no thread-local state is
                left behind, so a later call simply retries from scratch.
        """
        playwright = getattr(self._local, 'playwright', None)
        browser = getattr(self._local, 'browser', None)

        if playwright is None:
            playwright = sync_playwright().start()
            try:
                browser = playwright.chromium.launch(headless=True)
            except Exception:
                # Do not leave a started driver with no browser behind: that
                # half-initialised state would leak the driver process and
                # break every subsequent call on this thread.
                try:
                    playwright.stop()
                except Exception:
                    pass
                raise
            # Publish to the thread-local only once both objects exist.
            self._local.playwright = playwright
            self._local.browser = browser
            self._register_current_thread()
            print(f"🌐 Browser launched for thread {threading.current_thread().name}")
        elif browser is None or not browser.is_connected():
            # Driver is alive but the browser is missing or has died:
            # relaunch Chromium on the existing driver.
            self._local.browser = playwright.chromium.launch(headless=True)
            self._register_current_thread()
            print(f"🌐 Browser reconnected for thread {threading.current_thread().name}")

    def get_page(self, logical_width, logical_height, zoom):
        """Get a new page from the thread-local browser.

        Args:
            logical_width: Viewport width in CSS pixels.
            logical_height: Viewport height in CSS pixels.
            zoom: Value passed to Playwright as ``device_scale_factor``.

        Returns:
            The page object returned by ``browser.new_page``. The caller is
            responsible for closing it; the browser itself stays alive.
        """
        with self._lock:
            self._ensure_browser()
            return self._local.browser.new_page(
                viewport={"width": logical_width, "height": logical_height},
                device_scale_factor=zoom,
            )

    def shutdown(self):
        """Clean up browser and playwright resources for current thread.

        Safe to call when nothing was ever launched on this thread; errors
        raised while closing are ignored.
        """
        with self._lock:
            self._close_quietly(
                getattr(self._local, 'browser', None),
                getattr(self._local, 'playwright', None),
            )
            # Reset the slots so the next get_page() starts from scratch.
            self._local.browser = None
            self._local.playwright = None
            self._registry.pop(threading.get_ident(), None)
            print(
                f"🌐 Browser pool shut down for thread "
                f"{threading.current_thread().name}",
                flush=True,
            )

    def shutdown_all(self):
        """Close every Chromium instance we ever launched, regardless of thread.

        Called from an ``atexit`` hook on process exit. Closing a Playwright
        object created on another thread is technically off-pattern, but in
        practice the Playwright sync API tolerates it on shutdown and this
        prevents zombie Chromium processes.

        Only the registry is cleared; other threads' thread-local slots still
        reference the closed objects, so this is intended for final teardown.
        """
        with self._lock:
            for refs in self._registry.values():
                self._close_quietly(refs.get('browser'), refs.get('playwright'))
            self._registry.clear()
            print("🌐 BrowserPool.shutdown_all: all Chromium instances closed.", flush=True)


# Module-level browser pool instance: created once when this module is
# imported and shared by every function below that needs a browser.
_browser_pool = BrowserPool()


# Best-effort cleanup on process exit. Without this, hot-reload + Ctrl+C
# leak headless Chromium processes (one per worker thread that ever ran).
@atexit.register
def _shutdown_browser_pool():
    """Close every pooled browser when the interpreter exits.

    Registered with ``atexit``. Any error raised during teardown is
    deliberately discarded so that process exit is never blocked or made
    noisy by a cleanup failure.
    """
    try:
        pool = _browser_pool
        pool.shutdown_all()
    except Exception:
        # Best-effort only: there is nothing useful to do with a failure
        # this late in the process lifetime.
        return


def get_browser_pool():
    """Return the shared, module-level ``BrowserPool`` instance."""
    pool = _browser_pool
    return pool


def _split_save_path(save_path):
    """Split *save_path* into ``(base, extension)`` for numbered output names.

    Only a dot in the final path component counts as an extension separator,
    so directories such as ``out.v2/shot`` are left intact. When there is no
    extension, ``"png"`` is used.
    """
    save_path = os.fspath(save_path)
    base, ext = os.path.splitext(save_path)
    if not ext:
        return save_path, "png"
    return base, ext[1:]


def take_screenshot_playwright(
    html_content,
    save_path,
    zoom=2.1,              # Single zoom parameter (replaces font_size + device_scale)
    overlap=20,            # Logical pixels of overlap between consecutive shots
    viewport_width=1920,   # Final output width in physical pixels
    viewport_height=1080,  # Final output height in physical pixels
    max_screenshots=50,
    progress_callback=None, # Optional callback for SSE progress updates
    cancel_event=None       # Optional threading.Event to abort generation
):
    """
    Take multiple 1920×1080 screenshots of HTML content with configurable zoom.

    The HTML is written to a temporary file, opened in a pooled headless
    Chromium page, and captured one viewport at a time while scrolling down.

    The zoom is achieved purely via device_scale_factor:
      - logical viewport = (1920/zoom) × (1080/zoom)
      - device_scale_factor = zoom
      - output image = 1920 × 1080 (resized only if the raw capture differs)

    Output files are named ``<base>(1).<ext>``, ``<base>(2).<ext>``, ... where
    base/ext come from ``save_path``. Image data is always encoded as PNG,
    whatever the extension says.

    Args:
        html_content:    Full HTML string to render
        save_path:       Path for the first screenshot (e.g., "output.png")
        zoom:            Zoom level (2.5 = 250%); must be > 0
        overlap:         Overlap between shots in logical pixels; clamped to
                         the range [0, viewport height - 1]
        viewport_width:  Output image width
        viewport_height: Output image height
        max_screenshots: Safety cap
        progress_callback: Optional callable(message, progress_pct) for progress updates
        cancel_event: Optional threading.Event to monitor for user cancellation

    Returns:
        List of the screenshot file paths written, in capture order. The list
        is shorter than the estimate if ``cancel_event`` was set or
        ``max_screenshots`` was reached.

    Raises:
        ValueError: if ``zoom`` is not positive.
        Exception: errors from Playwright, Pillow or file I/O propagate after
            the page and the temporary HTML file have been cleaned up.
    """
    # Function-scope import keeps this block self-contained.
    from pathlib import Path

    if zoom <= 0:
        # Previously surfaced as ZeroDivisionError or a negative viewport.
        raise ValueError(f"zoom must be positive, got {zoom!r}")

    screenshots = []
    temp_path = None
    page = None

    try:
        # Write HTML to unique temp file (fixes race condition)
        fd, temp_path = tempfile.mkstemp(suffix='.html', prefix='screenshot_')
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            f.write(html_content)

        # as_uri() yields a well-formed, percent-encoded file:// URL on both
        # POSIX and Windows (hand-concatenation produced "file:////tmp/...").
        file_url = Path(os.path.abspath(temp_path)).as_uri()

        # --- Single zoom mechanism ---
        # Keep at least 1 logical pixel so a very large zoom cannot request a
        # zero-sized viewport.
        logical_width = max(1, int(viewport_width / zoom))
        logical_height = max(1, int(viewport_height / zoom))

        # A negative overlap would advance by more than one viewport per shot
        # and leave uncaptured gaps.
        if overlap < 0:
            overlap = 0

        # Clamp overlap below the logical viewport so the scroll step is
        # always positive.
        max_overlap = max(0, logical_height - 1)
        if overlap > max_overlap:
            print(
                f"⚠️ overlap={overlap} exceeds logical viewport height "
                f"{logical_height}; clamping to {max_overlap}",
                flush=True,
            )
            overlap = max_overlap

        # Use pooled browser
        page = _browser_pool.get_page(logical_width, logical_height, zoom)

        page.goto(file_url)
        page.wait_for_load_state("networkidle")
        time.sleep(1)  # extra settle time after network idle

        # Get page dimensions (clean, undistorted values)
        dimensions = page.evaluate("""
            () => ({
                scrollHeight: document.documentElement.scrollHeight,
                clientHeight: window.innerHeight
            })
        """)

        total_height = dimensions["scrollHeight"]
        viewport_h = dimensions["clientHeight"]

        print(f"πŸ“ Page: {total_height}px logical | Viewport: {viewport_h}px logical", flush=True)
        print(f"πŸ” Zoom: {zoom}x β†’ logical {logical_width}Γ—{logical_height} β†’ "
              f"output {viewport_width}Γ—{viewport_height}", flush=True)

        # Re-clamp against the real clientHeight (may differ from
        # logical_height) and force the step to >= 1 so neither the estimate
        # nor the scroll advance can divide by zero or stall.
        if overlap >= viewport_h:
            overlap = max(0, viewport_h - 1)
        step = max(1, viewport_h - overlap)
        num_est = max(1, -(-total_height // step))  # ceiling division
        print(f"πŸ“Έ Estimated {num_est} screenshot(s) (overlap={overlap}px)", flush=True)

        if progress_callback:
            progress_callback(f"Estimated {num_est} screenshot(s)", 10)

        # Prepare filenames
        base_path, extension = _split_save_path(save_path)
        target_size = (viewport_width, viewport_height)

        screenshot_count = 0
        scroll_position = 0

        while screenshot_count < max_screenshots:
            # Check for early cancellation
            if cancel_event and cancel_event.is_set():
                print("πŸ›‘ Screenshot generation aborted by user.")
                break

            screenshot_count += 1

            # Scroll to position
            page.evaluate(f"window.scrollTo(0, {scroll_position})")
            time.sleep(0.3)

            # Re-read both values each iteration: the browser may clamp the
            # requested scroll offset and the page height can change.
            actual_scroll = page.evaluate("window.pageYOffset")
            current_total = page.evaluate(
                "document.documentElement.scrollHeight"
            )

            print(
                f"πŸ“ Shot {screenshot_count}: "
                f"scroll={actual_scroll}/{current_total - viewport_h}", flush=True
            )

            # --- Consistent naming: base(1).png, base(2).png, ... ---
            screenshot_path = f"{base_path}({screenshot_count}).{extension}"

            # Capture screenshot
            screenshot_bytes = page.screenshot(full_page=False)
            img = Image.open(io.BytesIO(screenshot_bytes))

            # Verify exact target size
            if img.size != target_size:
                print(f"   ⚠️  Raw {img.size} β†’ resizing to {target_size}")
                img = img.resize(target_size, Image.Resampling.LANCZOS)

            # NOTE: always PNG-encoded, even if `extension` is something else.
            img.save(screenshot_path, "PNG")
            screenshots.append(screenshot_path)
            print(
                f"   βœ… {os.path.basename(screenshot_path)} "
                f"({img.size[0]}Γ—{img.size[1]})", flush=True
            )

            if progress_callback:
                pct = 10 + int((screenshot_count / max(num_est, 1)) * 80)
                progress_callback(
                    f"Captured screenshot {screenshot_count}/{num_est}",
                    min(pct, 90)
                )

            # --- Check if we've reached the bottom ---
            max_scroll = current_total - viewport_h

            # 2px tolerance for rounding in the reported scroll offset.
            if actual_scroll >= max_scroll - 2:
                print("🏁 Reached bottom of page.", flush=True)
                break

            # Advance with overlap, clamping so the very last strip is
            # captured flush with the bottom of the page.
            scroll_position = min(scroll_position + step, max_scroll)

        print(f"\nβœ… Done β€” {screenshot_count} screenshot(s) saved.")

    finally:
        # Close page (but browser stays alive in pool)
        if page is not None:
            try:
                page.close()
            except Exception:
                pass

        # Clean up temp file. A failure here must not mask an exception from
        # the capture itself or discard an otherwise successful result.
        if temp_path:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                print(f"Could not remove temp file {temp_path}: {exc}", flush=True)

    return screenshots


# Alias for compatibility: the old name is bound to the very same function
# object, so both names behave identically.
# NOTE(review): presumably kept for callers written against an earlier
# Selenium-based implementation — confirm before removing.
take_screenshot_selenium = take_screenshot_playwright