drbh committed on
Commit ·
aca891a
1
Parent(s): 4148918
fix: improve python bindings and sanity check
Browse files
- flake.lock +12 -12
- scripts/sanity.py +5 -9
- torch-ext/img2gray/__init__.py +6 -10
- torch-ext/torch_binding.cpp +7 -8
flake.lock
CHANGED
|
@@ -73,11 +73,11 @@
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
-
"lastModified":
|
| 77 |
-
"narHash": "sha256-
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
-
"rev": "
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
@@ -98,11 +98,11 @@
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
-
"lastModified":
|
| 102 |
-
"narHash": "sha256-
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
-
"rev": "
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|
|
@@ -113,17 +113,17 @@
|
|
| 113 |
},
|
| 114 |
"nixpkgs": {
|
| 115 |
"locked": {
|
| 116 |
-
"lastModified":
|
| 117 |
-
"narHash": "sha256-
|
| 118 |
-
"owner": "
|
| 119 |
"repo": "nixpkgs",
|
| 120 |
-
"rev": "
|
| 121 |
"type": "github"
|
| 122 |
},
|
| 123 |
"original": {
|
| 124 |
-
"owner": "
|
| 125 |
-
"ref": "cudatoolkit-12.9-kernel-builder",
|
| 126 |
"repo": "nixpkgs",
|
|
|
|
| 127 |
"type": "github"
|
| 128 |
}
|
| 129 |
},
|
|
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
+
"lastModified": 1754038838,
|
| 77 |
+
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
+
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
+
"lastModified": 1755181472,
|
| 102 |
+
"narHash": "sha256-xOXjhehC5xi/XB4fXZ5c0L2sSyDjJQdlH7/BcdHLBaM=",
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
+
"rev": "85da46f660c1c43b40771c3df3b223bb3fa39bec",
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|
|
|
|
| 113 |
},
|
| 114 |
"nixpkgs": {
|
| 115 |
"locked": {
|
| 116 |
+
"lastModified": 1752785354,
|
| 117 |
+
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
| 118 |
+
"owner": "nixos",
|
| 119 |
"repo": "nixpkgs",
|
| 120 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
| 121 |
"type": "github"
|
| 122 |
},
|
| 123 |
"original": {
|
| 124 |
+
"owner": "nixos",
|
|
|
|
| 125 |
"repo": "nixpkgs",
|
| 126 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
| 127 |
"type": "github"
|
| 128 |
}
|
| 129 |
},
|
scripts/sanity.py
CHANGED
|
@@ -6,18 +6,14 @@ import numpy as np
|
|
| 6 |
|
| 7 |
print(dir(img2gray))
|
| 8 |
|
| 9 |
-
img = Image.open("
|
| 10 |
img = np.array(img)
|
| 11 |
-
img_tensor = torch.from_numpy(img)
|
| 12 |
print(img_tensor.shape) # HWC
|
| 13 |
-
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).contiguous().cuda() # BCHW
|
| 14 |
-
print(img_tensor.shape) # BCHW
|
| 15 |
|
| 16 |
gray_tensor = img2gray.img2gray(img_tensor).squeeze()
|
| 17 |
-
print(gray_tensor.shape) #
|
| 18 |
|
| 19 |
# save the output image
|
| 20 |
-
gray_img = gray_tensor.cpu().numpy()
|
| 21 |
-
gray_img
|
| 22 |
-
|
| 23 |
-
gray_img.save("/home/ubuntu/Projects/img2gray/kernel-builder-logo-gray.png")
|
|
|
|
| 6 |
|
| 7 |
print(dir(img2gray))
|
| 8 |
|
| 9 |
+
img = Image.open("kernel-builder-logo-color.png").convert("RGB")
|
| 10 |
img = np.array(img)
|
| 11 |
+
img_tensor = torch.from_numpy(img).cuda()
|
| 12 |
print(img_tensor.shape) # HWC
|
|
|
|
|
|
|
| 13 |
|
| 14 |
gray_tensor = img2gray.img2gray(img_tensor).squeeze()
|
| 15 |
+
print(gray_tensor.shape) # HW
|
| 16 |
|
| 17 |
# save the output image
|
| 18 |
+
gray_img = Image.fromarray(gray_tensor.cpu().numpy().astype(np.uint8), mode="L")
|
| 19 |
+
gray_img.save("kernel-builder-logo-gray.png")
|
|
|
|
|
|
torch-ext/img2gray/__init__.py
CHANGED
|
@@ -2,17 +2,13 @@ import torch
|
|
| 2 |
|
| 3 |
from ._ops import ops
|
| 4 |
|
| 5 |
-
def img2gray(input: torch.Tensor) -> torch.Tensor:
|
| 6 |
-
# we expect input to be in BCHW format
|
| 7 |
-
batch, channels, height, width = input.shape
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
assert channels == 3, "Input image must have 3 channels (RGB)"
|
| 10 |
|
| 11 |
-
output = torch.empty((
|
| 12 |
-
|
| 13 |
-
for b in range(batch):
|
| 14 |
-
single_image = input[b].permute(1, 2, 0).contiguous() # HWC
|
| 15 |
-
single_output = output[b].reshape(height, width) # HW
|
| 16 |
-
ops.img2gray(single_image, single_output)
|
| 17 |
|
| 18 |
-
return output
|
|
|
|
| 2 |
|
| 3 |
from ._ops import ops
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
def img2gray(input: torch.Tensor) -> torch.Tensor:
|
| 7 |
+
# we expect input to be in HWC format
|
| 8 |
+
height, width, channels = input.shape
|
| 9 |
assert channels == 3, "Input image must have 3 channels (RGB)"
|
| 10 |
|
| 11 |
+
output = torch.empty((height, width), device=input.device, dtype=input.dtype)
|
| 12 |
+
ops.img2gray(input, output)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
return output
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -1,12 +1,11 @@
|
|
|
|
|
| 1 |
#include <torch/library.h>
|
| 2 |
-
|
| 3 |
-
#include "
|
| 4 |
-
#include "torch_binding.h"
|
| 5 |
-
|
| 6 |
|
| 7 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 8 |
-
|
| 9 |
-
|
| 10 |
}
|
| 11 |
-
|
| 12 |
-
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
|
| 1 |
+
// torch-ext/torch_binding.cpp
|
| 2 |
#include <torch/library.h>
|
| 3 |
+
#include "registration.h" // included in the build
|
| 4 |
+
#include "torch_binding.h" // Declares our img2gray_cuda function
|
|
|
|
|
|
|
| 5 |
|
| 6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 7 |
+
ops.def("img2gray(Tensor input, Tensor! output) -> ()");
|
| 8 |
+
ops.impl("img2gray", torch::kCUDA, &img2gray_cuda);
|
| 9 |
}
|
| 10 |
+
|
| 11 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|