| #include <torch/torch.h> |
|
|
| #import <Foundation/Foundation.h> |
| #import <Metal/Metal.h> |
|
|
| #ifdef EMBEDDED_METALLIB_HEADER |
| #include EMBEDDED_METALLIB_HEADER |
| #else |
| #error "EMBEDDED_METALLIB_HEADER not defined" |
| #endif |
|
|
| static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) { |
| return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data()); |
| } |
|
|
| void first_kernel(torch::Tensor &out, torch::Tensor const &input) { |
| TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); |
| TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); |
| TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, |
| "first_kernel only supports float32"); |
| TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape"); |
| TORCH_CHECK(input.scalar_type() == out.scalar_type(), |
| "Tensors must have same dtype"); |
| TORCH_CHECK(input.device() == out.device(), |
| "Tensors must be on same device"); |
|
|
| @autoreleasepool { |
| id<MTLDevice> device = MTLCreateSystemDefaultDevice(); |
| int numThreads = input.numel(); |
|
|
| NSError *error = nil; |
| id<MTLLibrary> library = |
| EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error); |
| TORCH_CHECK(library, "Failed to create Metal library: ", |
| error.localizedDescription.UTF8String); |
|
|
| id<MTLFunction> func = |
| [library newFunctionWithName:@"first_kernel_kernel"]; |
| TORCH_CHECK(func, "Failed to create function"); |
|
|
| id<MTLComputePipelineState> pso = |
| [device newComputePipelineStateWithFunction:func error:&error]; |
| TORCH_CHECK(pso, error.localizedDescription.UTF8String); |
|
|
| id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer(); |
| dispatch_sync(torch::mps::get_dispatch_queue(), ^() { |
| id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder]; |
| [encoder setComputePipelineState:pso]; |
| [encoder setBuffer:getMTLBufferStorage(input) |
| offset:input.storage_offset() * input.element_size() |
| atIndex:0]; |
| [encoder setBuffer:getMTLBufferStorage(out) |
| offset:out.storage_offset() * out.element_size() |
| atIndex:1]; |
|
|
| NSUInteger tgSize = |
| MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads); |
| [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1) |
| threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)]; |
| [encoder endEncoding]; |
| torch::mps::commit(); |
| }); |
| } |
| } |
|
|