#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// Recovers the Metal buffer backing an MPS tensor's storage. On the MPS
// backend the storage's raw data pointer actually holds an id<MTLBuffer>,
// so a bit-cast reinterprets it without any retain/release traffic.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  auto rawStorage = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, rawStorage);
}

/// Launches the `first_kernel_kernel` Metal compute function elementwise over
/// `input`, writing results into `out`.
///
/// Preconditions (checked): both tensors are float32, contiguous, same shape,
/// and live on the MPS device. Encoding happens on PyTorch's MPS dispatch
/// queue so it serializes correctly with other MPS work.
void first_kernel(torch::Tensor &out, torch::Tensor const &input) {
  TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  // The kernel writes `out` with the same linear indexing it reads `input`,
  // so `out` must be contiguous as well.
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float,
              "first_kernel only supports float32");
  TORCH_CHECK(input.sizes() == out.sizes(), "Tensors must have same shape");
  TORCH_CHECK(input.scalar_type() == out.scalar_type(),
              "Tensors must have same dtype");
  TORCH_CHECK(input.device() == out.device(),
              "Tensors must be on same device");

  // dispatchThreads with a zero-sized grid is invalid; empty tensors are a
  // no-op by definition.
  if (input.numel() == 0) {
    return;
  }

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    TORCH_CHECK(device, "Failed to get default Metal device");
    // numel() returns int64_t; keep the full width instead of narrowing to
    // int (previously truncated for tensors with > INT_MAX elements).
    NSUInteger numThreads = static_cast<NSUInteger>(input.numel());

    NSError *error = nil;
    id<MTLLibrary> library =
        EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
    // `error` may legitimately be nil on failure; guard before dereferencing
    // so TORCH_CHECK never streams a NULL char*.
    TORCH_CHECK(library, "Failed to create Metal library: ",
                error ? error.localizedDescription.UTF8String
                      : "unknown error");

    id<MTLFunction> func =
        [library newFunctionWithName:@"first_kernel_kernel"];
    TORCH_CHECK(func, "Failed to create function");

    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:func error:&error];
    TORCH_CHECK(pso, "Failed to create pipeline state: ",
                error ? error.localizedDescription.UTF8String
                      : "unknown error");

    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
    // Encode on the shared MPS dispatch queue so this kernel is ordered with
    // PyTorch's own MPS submissions.
    dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
      id<MTLComputeCommandEncoder> encoder = [cmdBuf computeCommandEncoder];
      [encoder setComputePipelineState:pso];
      // Offsets are in bytes; storage_offset() is in elements.
      [encoder setBuffer:getMTLBufferStorage(input)
                  offset:input.storage_offset() * input.element_size()
                 atIndex:0];
      [encoder setBuffer:getMTLBufferStorage(out)
                  offset:out.storage_offset() * out.element_size()
                 atIndex:1];

      // dispatchThreads handles the ragged final threadgroup, so the group
      // size only needs to respect the pipeline's hardware maximum.
      NSUInteger tgSize =
          MIN(pso.maxTotalThreadsPerThreadgroup, numThreads);
      [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
          threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
      [encoder endEncoding];
      torch::mps::commit();
    });
  }
}