
Commit 0064862

Add native strided API for MPSNDArray
1 parent 5b5d269 commit 0064862

19 files changed: 486 additions, 92 deletions

aten/src/ATen/mps/MPSDevice.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ enum class MacOSVersion : uint32_t {
   MACOS_VER_13_2_PLUS,
   MACOS_VER_13_3_PLUS,
   MACOS_VER_14_0_PLUS,
+  MACOS_VER_15_0_PLUS,
 };
 
 //-----------------------------------------------------------------

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 7 additions & 0 deletions
@@ -37,7 +37,11 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   if (!_mtl_indexing_library) {
     MTLCompileOptions* options = [MTLCompileOptions new];
     [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
+#if defined(__MAC_15_0)
+    options.mathMode = MTLMathModeFast;
+#else
     [options setFastMathEnabled:YES];
+#endif
     _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
                                                                                  encoding:NSASCIIStringEncoding]
                                                       options:options
@@ -118,6 +122,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   static bool _macos_13_3_plus = [compileOptions respondsToSelector:@selector(maxTotalThreadsPerThreadgroup)] == YES;
 
   static bool _macos_14_0_plus = [mpsCD instancesRespondToSelector:@selector(conjugateWithTensor:name:)] == YES;
+  static bool _macos_15_0_plus = [mpsCD respondsToSelector:@selector(variableFromTensorWithTensor:name:)] == YES;
 
   switch (version) {
     case MacOSVersion::MACOS_VER_13_0_PLUS:
@@ -130,6 +135,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
       return _macos_13_3_plus;
     case MacOSVersion::MACOS_VER_14_0_PLUS:
       return _macos_14_0_plus;
+    case MacOSVersion::MACOS_VER_15_0_PLUS:
+      return _macos_15_0_plus;
     default:
       return false;
   }
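The macOS 15 capability is detected at runtime by probing a selector that first shipped with the macOS 15 SDK, rather than by parsing version strings. Code in the MPS backend reaches it through the existing is_macos_13_or_newer() helper, exactly as the needsGather() change further down does. A minimal sketch of such a call site, assuming it lives where the MPSDevice.h helpers are visible (the wrapper function itself is illustrative, not part of this commit):

// Illustrative sketch: gating a strided fast path on the new runtime check.
// is_macos_13_or_newer() and MacOSVersion are the existing helpers from
// ATen/mps/MPSDevice.h; only the MACOS_VER_15_0_PLUS value is new here.
static bool canUseNativeStridedNDArray() {
  static const bool is_macos_15_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  return is_macos_15_plus;
}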

aten/src/ATen/native/Convolution.cpp

Lines changed: 1 addition & 1 deletion
@@ -1665,7 +1665,7 @@ at::Tensor _convolution(
                 "Input type (", input.toString(), ") and bias type (", bias.toString(),
                 ") should be the same");
 
-    output = at::_mps_convolution(input.contiguous(), weight, bias.defined() ? bias.contiguous() : bias,
+    output = at::_mps_convolution(input, weight, bias.defined() ? bias.contiguous() : bias,
                                   params.padding, params.stride, params.dilation,
                                   params.groups);
 #else
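With the strided path available, the MPS convolution no longer forces its input contiguous on the ATen side; on macOS 15 a non-contiguous input can be consumed in place, while older systems still fall back to the gather path inside Placeholder. A hedged libtorch-level illustration of the kind of input this affects (the example program is not part of the commit):

#include <ATen/ATen.h>

int main() {
  // A permuted (non-contiguous) input on the MPS device.
  auto x = at::rand({1, 8, 16, 16}, at::device(at::kMPS)).permute({0, 1, 3, 2});
  auto w = at::rand({4, 8, 3, 3}, at::device(at::kMPS));
  // Previously the MPS branch called input.contiguous() here, materializing a
  // copy of x; now the strided tensor is handed to _mps_convolution as-is.
  auto y = at::conv2d(x, w);
  return y.defined() ? 0 : 1;
}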
aten/src/ATen/native/mps/MPSGraphSequoiaOps.h

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
+
+#if !defined(__MAC_15_0) && \
+    (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
+
+@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
+-(MPSNDArray * __nullable) reshapeWithCommandBuffer: (__nullable id <MTLCommandBuffer>) cmdBuf
+                                         sourceArray: (MPSNDArray * __nonnull) sourceArray
+                                               shape: (MPSShape * __nonnull) shape
+                                    destinationArray: (MPSNDArray * __nullable) destinationArray;
+@end
+
+@interface MPSNDArrayDescriptor()
+@property (readwrite, nonatomic) BOOL preferPackedRows;
+@end
+
+@interface MPSNDArray()
+-(nonnull instancetype) initWithBuffer:(id<MTLBuffer> _Nonnull) buffer
+                                offset:(NSUInteger) offset
+                            descriptor:(MPSNDArrayDescriptor * _Nonnull) descriptor;
+-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape
+                                      strides:(MPSShape * _Nonnull) strides;
+@end
+
+#endif
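These category declarations mirror the MPSNDArray API that ships with the macOS 15 SDK so the code below can compile against older SDKs; the calls remain gated on MACOS_VER_15_0_PLUS at runtime. A minimal sketch of how the two key entry points combine to alias an existing Metal buffer as a strided array, assuming a float32 layout and a valid buffer (the helper itself is illustrative):

// Sketch only: wrap an MTLBuffer as a strided MPSNDArray view (no copy),
// using the interfaces declared above. Mirrors getMPSNDArray() further down.
static MPSNDArray* _Nullable wrapAsStridedNDArray(id<MTLBuffer> _Nonnull buffer,
                                                  NSUInteger byteOffset,
                                                  MPSShape* _Nonnull sizes,
                                                  MPSShape* _Nonnull strides) {
  MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeFloat32
                                                                      shape:sizes];
  desc.preferPackedRows = YES;  // rows are tightly packed in the backing buffer
  MPSNDArray* ndarray = [[[MPSNDArray alloc] initWithBuffer:buffer
                                                     offset:byteOffset
                                                 descriptor:desc] autorelease];
  // Re-express the same storage with explicit strides.
  return [ndarray arrayViewWithShape:sizes strides:strides];
}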

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 7 additions & 2 deletions
@@ -88,7 +88,10 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
 MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
 MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
 
+MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
+MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes = nil, MPSShape* strides = nil);
 // The MPSShape could vary based on memory format
+Tensor getTensorView(const Tensor& t, MPSShape* shape);
 MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
 MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
 
@@ -100,8 +103,9 @@ class Placeholder {
  public:
   Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
   Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
+  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
   Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
-              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid);
+              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
   MPSGraphTensor* getMPSGraphTensor() {
     return _placeholder;
   }
@@ -431,7 +435,8 @@ inline bool supportedFloatingOrComplexType(const Tensor& t) {
 
 
 inline bool needsGather(const Tensor& t) {
-  return !t.is_contiguous() || t.storage_offset();
+  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
+  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
 }
 
 } // namespace at::native::mps
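The reworked needsGather() keeps the old behavior on macOS 14 and earlier but always returns false on macOS 15, so strided views flow straight into the new NDArray path. An illustrative caller-side pattern (not from this commit) for ops that still require a dense input:

// Illustrative only: ops that cannot consume strided inputs still materialize
// a contiguous copy when needsGather() says so; on macOS 15 this returns the
// view unchanged and the strided MPSNDArray path is used instead.
static Tensor materializeIfNeeded(const Tensor& t) {
  return needsGather(t) ? t.contiguous() : t;
}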

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 209 additions & 15 deletions
@@ -3,6 +3,7 @@
 #include <ATen/TensorIterator.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/mps/MPSGraphSequoiaOps.h>
 #include <ATen/native/mps/MPSGraphSonomaOps.h>
 #include <ATen/native/mps/MPSGraphVenturaOps.h>
 #include <ATen/native/mps/OperationUtils.h>
@@ -303,6 +304,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
   return str;
 }
 
+Tensor getTensorView(const Tensor& t, MPSShape* shape) {
+  std::vector<int64_t> res;
+  res.reserve([shape count]);
+  for (NSNumber* elem in shape) {
+    res.push_back(elem.longLongValue);
+  }
+  IntArrayRef r = IntArrayRef(res);
+  return t.view(res);
+}
+
 MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) {
   return getMPSShape(t.sizes(), memory_format);
 }
@@ -359,26 +370,152 @@ void printTensorNDArray(const Tensor& t) {
   return [tmpGraphTensorData mpsndarray];
 }
 
+static std::vector<int64_t> getSortedStrides(const IntArrayRef& s) {
+  std::vector<int64_t> idx(s.size());
+  iota(idx.begin(), idx.end(), 0);
+  sort(idx.begin(), idx.end(), [&s](size_t i1, size_t i2) { return s[i1] > s[i2]; });
+
+  return idx;
+}
+
+static std::vector<int64_t> inversePermutation(std::vector<int64_t> permuteOrder) {
+  auto size = permuteOrder.size();
+  std::vector<int64_t> inversePerm(permuteOrder.size());
+
+  for (int i = 0; i < size; i++) {
+    inversePerm[permuteOrder[i]] = i;
+  }
+  return inversePerm;
+}
+
+static MPSNDArray* permuteNDArray(MPSNDArray* inArray, std::vector<int64_t> permuteOrder_) {
+  auto permuteOrder = inversePermutation(permuteOrder_);
+  NSUInteger srcRank = [inArray numberOfDimensions];
+  if (srcRank != permuteOrder.size()) {
+    TORCH_INTERNAL_ASSERT(false);
+    return nil;
+  }
+  std::vector<NSUInteger> dimensionOrder(srcRank);
+  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
+  MPSNDArrayDescriptor* desc = [inArray descriptor];
+
+  for (int64_t i = srcRank - 1; i >= 0; i--) {
+    NSUInteger axis = permuteOrder[i];
+    auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis);
+    NSUInteger axis1 = srcRank - i - 1;
+    NSUInteger axis2 = dimensionOrder.end() - axisIter - 1;
+    iter_swap(dimensionOrder.begin() + i, axisIter);
+    if (axis1 != axis2) {
+      [desc transposeDimension:axis1 withDimension:axis2];
+    }
+  }
+  C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wnonnull")
+  C10_CLANG_DIAGNOSTIC_IGNORE("-Wnonnull")
+#endif
+  MPSNDArray* result = [inArray arrayViewWithCommandBuffer:nil descriptor:desc aliasing:MPSAliasingStrategyShallAlias];
+  C10_CLANG_DIAGNOSTIC_POP()
+
+  TORCH_INTERNAL_ASSERT(result != nil);
+  return result;
+}
+
+MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes, MPSShape* strides) {
+  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
+
+  MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
+  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
+  srcTensorDesc.preferPackedRows = YES;
+  MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
+                                                        offset:t.storage_offset() * t.element_size()
+                                                    descriptor:srcTensorDesc] autorelease];
+  if (strides != nil) {
+    srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
+  }
+  return srcNDArray;
+}
+
+MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes, const IntArrayRef& strides) {
+  return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
+}
+
+static MPSNDArray* getStridedMPSNDArray(const at::Tensor& src, MPSNDArray* srcNDArray) {
+  auto strides = src.strides();
+  auto sizes = src.sizes();
+  auto nStrides = strides.size();
+  auto nonZeroStrides = src.strides();
+  int64_t crtNonZeroStride = 1;
+  bool hasZeroStrides = false;
+  auto sortedStridesIndices = getSortedStrides(nonZeroStrides);
+
+  NSMutableArray<NSNumber*>* sortedStridesShape = [NSMutableArray arrayWithCapacity:nStrides];
+  NSMutableArray<NSNumber*>* sortedMPSShape = [NSMutableArray arrayWithCapacity:nStrides];
+  for (const auto i : c10::irange(nStrides)) {
+    sortedStridesShape[i] = [NSNumber numberWithInteger:nonZeroStrides[sortedStridesIndices[i]]];
+    sortedMPSShape[i] = [NSNumber numberWithInteger:sizes[sortedStridesIndices[i]]];
+  }
+  MPSShape* originalSortedMPSShape = sortedMPSShape;
+  MPSShape* originalSortedStridesShape = sortedStridesShape;
+  bool hasNonZeroStrides = nStrides == 0 ? false : nonZeroStrides[sortedStridesIndices[nStrides - 1]] != 1;
+  if (hasNonZeroStrides) {
+    originalSortedMPSShape = [sortedMPSShape copy];
+    originalSortedStridesShape = [sortedStridesShape copy];
+    [sortedStridesShape addObject:[NSNumber numberWithInteger:1]];
+    [sortedMPSShape addObject:[NSNumber numberWithInteger:1]];
+  }
+  if (nStrides == 0) {
+    originalSortedMPSShape = getMPSShape(src);
+    originalSortedStridesShape = getMPSShape(src.strides());
+  }
+
+  srcNDArray = [srcNDArray arrayViewWithShape:sortedMPSShape strides:sortedStridesShape];
+  if (hasNonZeroStrides) {
+    MPSNDArrayIdentity* identity =
+        [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
+    srcNDArray = [identity reshapeWithCommandBuffer:nil
+                                        sourceArray:srcNDArray
+                                              shape:originalSortedMPSShape
+                                   destinationArray:nil];
+  }
+  TORCH_INTERNAL_ASSERT(srcNDArray);
+
+  srcNDArray = permuteNDArray(srcNDArray, sortedStridesIndices);
+  TORCH_INTERNAL_ASSERT(srcNDArray);
+
+  return srcNDArray;
+}
+
+Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray) {
+  _placeholder = mpsGraphTensor;
+  _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:mpsNDArray] autorelease];
+}
+
 Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
                          const Tensor& src,
-                         MPSShape* mpsShape,
+                         MPSShape* mpsShape_,
                          bool gatherTensorData,
-                         MPSDataType dataType)
+                         MPSDataType dataType,
+                         bool useMPSStridedAPI)
     : _tensor(src) {
   TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
-  // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
-  if (needsGather(src) && gatherTensorData) {
-    Tensor emptyShell = Tensor();
-    // use "_tensor" from Placeholder to retain view's output during its usage in other ops
-    _tensor = gatherViewTensor(src, emptyShell);
-    if (!_tensor.has_storage()) {
-      // if we cannot gather, we make the tensor contiguous implicitly, and keep
-      // it in placeholder to be able to retrieve it when we return from constructor
-      _tensor = src.clone(MemoryFormat::Contiguous);
+
+  const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
+  // Use gather kernel to solve strides for macOS < 15.0
+  // Starting with macOS 15.0, MPS supports native strides directly in the kernels
+  if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
+    if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
+      Tensor emptyShell = Tensor();
+      // use "_tensor" from Placeholder to retain view's output during its usage in other ops
+      _tensor = gatherViewTensor(src, emptyShell);
+      if (!_tensor.has_storage()) {
+        // if we cannot gather, we make the tensor contiguous implicitly, and keep
+        // it in placeholder to be able to retrieve it when we return from constructor
+        _tensor = src.clone(MemoryFormat::Contiguous);
+      }
+      srcBuf = getMTLBufferStorage(_tensor);
     }
-    srcBuf = getMTLBufferStorage(_tensor);
   }
 
   // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
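getStridedMPSNDArray() sorts dimensions by descending stride, views the buffer with that dense layout, and then permutes the result back into the tensor's dimension order; inversePermutation() supplies the permutation that permuteNDArray() applies. A small self-contained check of that index bookkeeping on an example layout, a contiguous 2x3x4 tensor seen through permute(2, 0, 1) (the program is illustrative and independent of MPS):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // A contiguous 2x3x4 tensor permuted with permute(2, 0, 1):
  // the view has sizes {4, 2, 3} and strides {1, 12, 4}.
  std::vector<int64_t> sizes = {4, 2, 3};
  std::vector<int64_t> strides = {1, 12, 4};

  // getSortedStrides(): dimension indices sorted by descending stride.
  std::vector<int64_t> idx(strides.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](int64_t a, int64_t b) { return strides[a] > strides[b]; });
  assert((idx == std::vector<int64_t>{1, 2, 0}));

  // Reordering sizes/strides by idx recovers the dense base layout
  // {2, 3, 4} / {12, 4, 1}, which is what the sorted arrayView is created with.
  assert(sizes[idx[0]] == 2 && strides[idx[0]] == 12);
  assert(sizes[idx[2]] == 4 && strides[idx[2]] == 1);

  // inversePermutation(): maps the dense view back to the original dimension
  // order; this is the order permuteNDArray() restores.
  std::vector<int64_t> inverse(idx.size());
  for (size_t i = 0; i < idx.size(); i++) {
    inverse[idx[i]] = static_cast<int64_t>(i);
  }
  assert((inverse == std::vector<int64_t>{2, 0, 1}));
  return 0;
}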
@@ -389,9 +526,66 @@ void printTensorNDArray(const Tensor& t) {
     const auto scalar_type = _tensor.scalar_type();
     dataType = _tensor.dim() == 0 ? getMPSScalarType(scalar_type) : getMPSDataType(scalar_type);
   }
-  _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
-                                                    shape:mpsShape ? mpsShape : getMPSShape(_tensor)
-                                                 dataType:dataType] autorelease];
+
+  // Tensor is contiguous and has no storage offset.
+  // Wrap it directly inside MPSGraphTensorData
+  if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
+    _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
+                                                      shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
+                                                   dataType:dataType] autorelease];
+  } else {
+    IntArrayRef view_shape;
+    if (mpsShape_) {
+      _tensor = getTensorView(src, mpsShape_);
+    }
+
+    MPSShape* mpsShape = getMPSShape(_tensor);
+    MPSShape* mpsStrides = getMPSShape(_tensor.strides());
+
+    IntArrayRef baseShape;
+    if (src.is_view()) {
+      baseShape = src._base().sizes();
+    } else {
+      baseShape = getIMPSAllocator()->getBufferShape(src.storage().data());
+    }
+    int flattenedShaped = 1;
+    for (const auto i : c10::irange(baseShape.size())) {
+      flattenedShaped *= baseShape[i];
+    }
+    MPSShape* mpsBaseShape = @[ @(flattenedShaped) ];
+    MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsBaseShape];
+    srcTensorDesc.preferPackedRows = YES;
+    MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
+                                                          offset:src.storage_offset() * src.element_size()
+                                                      descriptor:srcTensorDesc] autorelease];
+    TORCH_INTERNAL_ASSERT(srcNDArray);
+    if (src.dim() != 0) {
+      srcNDArray = getStridedMPSNDArray(_tensor, srcNDArray);
+    } else {
+      bool needsReshape = false;
+      NSMutableArray* mpsExpandedShape = nil;
+      NSMutableArray* mpsExpandedStrides = nil;
+
+      if (src.dim() > 0 && src.stride(-1) != 1) {
+        needsReshape = true;
+        mpsExpandedShape = [NSMutableArray arrayWithArray:mpsShape];
+        mpsExpandedStrides = [NSMutableArray arrayWithArray:mpsStrides];
+        [mpsExpandedShape addObject:@1];
+        [mpsExpandedStrides addObject:@1];
+      }
+      srcNDArray = [srcNDArray arrayViewWithShape:needsReshape ? mpsExpandedShape : getMPSShape(src)
+                                          strides:needsReshape ? mpsExpandedStrides : getMPSShape(src.strides())];
+      TORCH_INTERNAL_ASSERT(srcNDArray);
+
+      if (needsReshape) {
+        MPSNDArrayIdentity* identity =
+            [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
+        srcNDArray = [identity reshapeWithCommandBuffer:nil sourceArray:srcNDArray shape:mpsShape destinationArray:nil];
+      }
+      TORCH_INTERNAL_ASSERT(srcNDArray);
+    }
+    _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcNDArray] autorelease];
+  }
 
   TORCH_INTERNAL_ASSERT(_value);
   _placeholder = mpsGraphTensor;
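Taken together: on macOS 15 a Placeholder built from a non-contiguous view no longer gathers into a scratch buffer; it wraps the original storage in an MPSNDArray carrying the view's sizes and strides and feeds that to MPSGraphTensorData. Kernels can also reach the same machinery directly through the getMPSNDArray() overloads added above; a minimal sketch (the wrapper function name is illustrative):

// Sketch only: obtain an MPSNDArray that aliases a tensor's storage with its
// view sizes and strides (macOS 15 path), using the overload added above.
static MPSNDArray* ndarrayForView(const at::Tensor& t) {
  // Passing t.strides() requests a strided view of the buffer instead of
  // assuming a packed layout; the defaults would use t.sizes() with no strides.
  return at::native::mps::getMPSNDArray(t, t.sizes(), t.strides());
}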
