// NOTE(review): the header names on the first three directives were missing in
// the reviewed text. They are restored here from the symbols this file uses
// (torch::Tensor, NSError, MTL* protocols) -- confirm against the build.
#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

/// Reinterprets the data pointer of `tensor`'s storage as a Metal buffer handle.
///
/// No ownership is transferred and no validation is performed; the caller is
/// expected to pass an MPS tensor (first_kernel checks this before calling).
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

/// Returns a printable description for `error`.
///
/// Messaging a nil NSError yields a NULL C string, which must not be streamed
/// into a TORCH_CHECK message, so a fixed fallback is substituted.
static inline const char *describeNSError(NSError *error) {
  const char *text = error.localizedDescription.UTF8String;
  return text ? text : "unknown error";
}

/// Launches the `first_kernel_kernel` Metal compute function with one thread
/// per element, binding `input` at buffer index 0 and `out` at buffer index 1.
///
/// @param out    Destination tensor; must match `input` in shape, dtype and
///               device, and must be contiguous.
/// @param input  Source tensor; must be a contiguous float32 MPS tensor.
///
/// Raises (via TORCH_CHECK) if any precondition fails, or if the Metal
/// library, function, or pipeline state cannot be created. Empty tensors are
/// a no-op.
void first_kernel(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float,
              "first_kernel only supports float32");
  TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have same dtype");
  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on same device");
  // Both buffers are bound with a single base offset and addressed linearly,
  // so the destination must be contiguous just like the source.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");

  // A zero-element dispatch would request a zero-sized threadgroup; there is
  // nothing to compute, so return before touching Metal at all.
  const int64_t elementCount = input.numel();
  if (elementCount == 0) {
    return;
  }
  // Keep the full 64-bit count instead of narrowing to int.
  const NSUInteger numThreads = static_cast<NSUInteger>(elementCount);

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();

    NSError *error = nil;
    id<MTLLibrary> library =
        EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
    TORCH_CHECK(library, "Failed to create Metal library: ",
                describeNSError(error));

    id<MTLFunction> func = [library newFunctionWithName:@"first_kernel_kernel"];
    TORCH_CHECK(func, "Failed to create function");

    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:func error:&error];
    TORCH_CHECK(pso, describeNSError(error));

    // Resolve buffers and byte offsets up front so the block below captures
    // plain values rather than the tensor references.
    id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
    id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
    const NSUInteger inputOffset =
        static_cast<NSUInteger>(input.storage_offset() * input.element_size());
    const NSUInteger outOffset =
        static_cast<NSUInteger>(out.storage_offset() * out.element_size());

    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();

    // Encoding happens on the queue returned by torch::mps::get_dispatch_queue().
    dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
      id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
      [encoder setComputePipelineState:pso];
      [encoder setBuffer:inputBuffer offset:inputOffset atIndex:0];
      [encoder setBuffer:outBuffer offset:outOffset atIndex:1];

      // One-dimensional launch: cap the threadgroup width at the pipeline
      // limit, or at the element count when the tensor is smaller than that.
      NSUInteger tgSize = MIN(pso.maxTotalThreadsPerThreadgroup, numThreads);
      [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
          threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
      [encoder endEncoding];

      torch::mps::commit();
    });
  }
}