import random
import io
import cv2
import numpy as np
from PIL import Image
import torch
# Disable OpenCV's internal threading; presumably this runs inside dataloader
# workers where per-process threads oversubscribe cores — TODO confirm caller context.
cv2.setNumThreads(0)
# Settings: inclusive (low, high) ranges, sampled uniformly per call below.
jpeg_quality_range = (70, 90) # Higher, Better Quality
webp_quality_range = (70, 90)
webp_encode_speed = (3, 5) # Higher, Slower, Better Quality (PIL WEBP "method" knob)
def _to_hwc_uint8(x: torch.Tensor) -> np.ndarray:
t = x.detach()
keep_batch_dim = False
if t.ndim == 4:
keep_batch_dim = True
if t.shape[0] != 1:
raise ValueError(f"Expect batch size 1 if 4D, got {tuple(t.shape)}")
t = t.squeeze(0)
if t.ndim != 3:
raise ValueError(f"Expect 3D tensor after squeeze, got {tuple(t.shape)}")
if t.shape[0] == 3:
t = t.permute(1, 2, 0) # CHW -> HWC
elif t.shape[2] == 3:
pass # already HWC
else:
raise ValueError(f"Expect CHW or HWC with 3 channels, got {tuple(t.shape)}")
if t.dtype != torch.uint8:
t = (t.clamp(0, 1) * 255.0).round().to(torch.uint8)
arr = t.cpu().numpy() # HWC uint8, usually contiguous
# ensure contiguous anyway
return np.ascontiguousarray(arr)
def _from_hwc_uint8(img: np.ndarray, keep_batch_dim: bool) -> torch.Tensor:
if img.ndim != 3 or img.shape[2] != 3:
raise ValueError(f"Expect (H,W,3), got {img.shape}")
if img.dtype != np.uint8:
img = np.clip(img, 0, 255).astype(np.uint8)
img = np.ascontiguousarray(img) # <-- critical for safety
t = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0 # (3,H,W)
return t.unsqueeze(0) if keep_batch_dim else t
def jpeg_compress_tensor(tensor_frames: torch.Tensor) -> torch.Tensor:
    """Degrade an image tensor with a random-quality JPEG round-trip.

    Quality is drawn uniformly from jpeg_quality_range. Accepts the same
    layouts as _to_hwc_uint8; the batch dimension is restored when the
    input had one.

    Raises:
        RuntimeError: if OpenCV encoding or decoding fails.
    """
    keep_batch_dim = (tensor_frames.ndim == 4)
    rgb = _to_hwc_uint8(tensor_frames)
    # OpenCV expects BGR; the reversed view has negative strides, so copy.
    bgr = np.ascontiguousarray(rgb[..., ::-1])
    quality = random.randint(*jpeg_quality_range)
    ok, encoded = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
    if not ok:
        raise RuntimeError("cv2.imencode('.jpg') failed")
    decoded_bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if decoded_bgr is None:
        raise RuntimeError("cv2.imdecode() failed")
    decoded_rgb = np.ascontiguousarray(decoded_bgr[..., ::-1])  # back to RGB
    return _from_hwc_uint8(decoded_rgb, keep_batch_dim)
def webp_compress_tensor(tensor_frames: torch.Tensor) -> torch.Tensor:
    """Degrade an image tensor with a random-quality WEBP round-trip.

    Quality and encoder effort ("method") are drawn uniformly from
    webp_quality_range and webp_encode_speed. Accepts the same layouts as
    _to_hwc_uint8; the batch dimension is restored when the input had one.
    """
    keep_batch_dim = (tensor_frames.ndim == 4)
    rgb = _to_hwc_uint8(tensor_frames)
    quality = random.randint(*webp_quality_range)
    method = random.randint(*webp_encode_speed)
    # Encode to an in-memory WEBP, then decode it straight back.
    encoded = io.BytesIO()
    Image.fromarray(rgb, mode="RGB").save(
        encoded, format="WEBP", quality=int(quality), method=int(method)
    )
    reopened = Image.open(io.BytesIO(encoded.getvalue())).convert("RGB")
    decoded = np.ascontiguousarray(np.array(reopened, dtype=np.uint8))
    return _from_hwc_uint8(decoded, keep_batch_dim)