 #include <ATen/TensorIterator.h>
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/mps/MPSGraphSequoiaOps.h>
 #include <ATen/native/mps/MPSGraphSonomaOps.h>
 #include <ATen/native/mps/MPSGraphVenturaOps.h>
 #include <ATen/native/mps/OperationUtils.h>
@@ -303,6 +304,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
   return str;
 }
 
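+// Reinterpret `t` with the sizes carried by `shape`: convert the MPSShape into a
+// std::vector<int64_t> and return the corresponding Tensor::view.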
+Tensor getTensorView(const Tensor& t, MPSShape* shape) {
+  std::vector<int64_t> res;
+  res.reserve([shape count]);
+  for (NSNumber* elem in shape) {
+    res.push_back(elem.longLongValue);
+  }
+  IntArrayRef r = IntArrayRef(res);
+  return t.view(res);
+}
+
 MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) {
   return getMPSShape(t.sizes(), memory_format);
 }
@@ -359,26 +370,152 @@ void printTensorNDArray(const Tensor& t) {
   return [tmpGraphTensorData mpsndarray];
 }
 
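+// Return the indices of `s` ordered from largest to smallest stride.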
+static std::vector<int64_t> getSortedStrides(const IntArrayRef& s) {
+  std::vector<int64_t> idx(s.size());
+  iota(idx.begin(), idx.end(), 0);
+  sort(idx.begin(), idx.end(), [&s](size_t i1, size_t i2) { return s[i1] > s[i2]; });
+
+  return idx;
+}
+
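+// Given a permutation, return its inverse: inversePerm[permuteOrder[i]] == i.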
+static std::vector<int64_t> inversePermutation(std::vector<int64_t> permuteOrder) {
+  auto size = permuteOrder.size();
+  std::vector<int64_t> inversePerm(permuteOrder.size());
+
+  for (int i = 0; i < size; i++) {
+    inversePerm[permuteOrder[i]] = i;
+  }
+  return inversePerm;
+}
+
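+// Return an aliasing view of `inArray` whose dimensions are transposed according to the
+// inverse of `permuteOrder_`, realized by swapping dimensions on the array descriptor.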
+static MPSNDArray* permuteNDArray(MPSNDArray* inArray, std::vector<int64_t> permuteOrder_) {
+  auto permuteOrder = inversePermutation(permuteOrder_);
+  NSUInteger srcRank = [inArray numberOfDimensions];
+  if (srcRank != permuteOrder.size()) {
+    TORCH_INTERNAL_ASSERT(false);
+    return nil;
+  }
+  std::vector<NSUInteger> dimensionOrder(srcRank);
+  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
+  MPSNDArrayDescriptor* desc = [inArray descriptor];
+
+  for (int64_t i = srcRank - 1; i >= 0; i--) {
+    NSUInteger axis = permuteOrder[i];
+    auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis);
+    NSUInteger axis1 = srcRank - i - 1;
+    NSUInteger axis2 = dimensionOrder.end() - axisIter - 1;
+    iter_swap(dimensionOrder.begin() + i, axisIter);
+    if (axis1 != axis2) {
+      [desc transposeDimension:axis1 withDimension:axis2];
+    }
+  }
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wnonnull")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wnonnull")
+#endif
+  MPSNDArray* result = [inArray arrayViewWithCommandBuffer:nil descriptor:desc aliasing:MPSAliasingStrategyShallAlias];
+C10_CLANG_DIAGNOSTIC_POP()
+
+  TORCH_INTERNAL_ASSERT(result != nil);
+  return result;
+}
+
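+// Wrap the tensor's MTLBuffer in an MPSNDArray with the given sizes and, if provided,
+// reinterpret the view with explicit strides. The IntArrayRef overload below falls back
+// to the tensor's own sizes when `sizes` is empty.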
+MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes, MPSShape* strides) {
+  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
+
+  MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
+  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
+  srcTensorDesc.preferPackedRows = YES;
+  MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
+                                                        offset:t.storage_offset() * t.element_size()
+                                                    descriptor:srcTensorDesc] autorelease];
+  if (strides != nil) {
+    srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
+  }
+  return srcNDArray;
+}
+
+MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes, const IntArrayRef& strides) {
+  return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
+}
+
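+// Build an MPSNDArray view that matches `src`'s sizes and strides: sort the strides in
+// descending order to build the view, append a unit dimension and reshape it away when the
+// smallest stride is not 1, then permute the dimensions back to the tensor's original order.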
+static MPSNDArray* getStridedMPSNDArray(const at::Tensor& src, MPSNDArray* srcNDArray) {
+  auto strides = src.strides();
+  auto sizes = src.sizes();
+  auto nStrides = strides.size();
+  auto nonZeroStrides = src.strides();
+  int64_t crtNonZeroStride = 1;
+  bool hasZeroStrides = false;
+  auto sortedStridesIndices = getSortedStrides(nonZeroStrides);
+
+  NSMutableArray<NSNumber*>* sortedStridesShape = [NSMutableArray arrayWithCapacity:nStrides];
+  NSMutableArray<NSNumber*>* sortedMPSShape = [NSMutableArray arrayWithCapacity:nStrides];
+  for (const auto i : c10::irange(nStrides)) {
+    sortedStridesShape[i] = [NSNumber numberWithInteger:nonZeroStrides[sortedStridesIndices[i]]];
+    sortedMPSShape[i] = [NSNumber numberWithInteger:sizes[sortedStridesIndices[i]]];
+  }
+  MPSShape* originalSortedMPSShape = sortedMPSShape;
+  MPSShape* originalSortedStridesShape = sortedStridesShape;
+  bool hasNonZeroStrides = nStrides == 0 ? false : nonZeroStrides[sortedStridesIndices[nStrides - 1]] != 1;
+  if (hasNonZeroStrides) {
+    originalSortedMPSShape = [sortedMPSShape copy];
+    originalSortedStridesShape = [sortedStridesShape copy];
+    [sortedStridesShape addObject:[NSNumber numberWithInteger:1]];
+    [sortedMPSShape addObject:[NSNumber numberWithInteger:1]];
+  }
+  if (nStrides == 0) {
+    originalSortedMPSShape = getMPSShape(src);
+    originalSortedStridesShape = getMPSShape(src.strides());
+  }
+
+  srcNDArray = [srcNDArray arrayViewWithShape:sortedMPSShape strides:sortedStridesShape];
+  if (hasNonZeroStrides) {
+    MPSNDArrayIdentity* identity =
+        [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
+    srcNDArray = [identity reshapeWithCommandBuffer:nil
+                                        sourceArray:srcNDArray
+                                              shape:originalSortedMPSShape
+                                   destinationArray:nil];
+  }
+  TORCH_INTERNAL_ASSERT(srcNDArray);
+
+  srcNDArray = permuteNDArray(srcNDArray, sortedStridesIndices);
+  TORCH_INTERNAL_ASSERT(srcNDArray);
+
+  return srcNDArray;
+}
+
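+// Construct a Placeholder that wraps an already-built MPSNDArray directly in
+// MPSGraphTensorData, without any gathering or copying.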
+Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray) {
+  _placeholder = mpsGraphTensor;
+  _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:mpsNDArray] autorelease];
+}
+
 Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
                          const Tensor& src,
-                         MPSShape* mpsShape,
+                         MPSShape* mpsShape_,
                          bool gatherTensorData,
-                         MPSDataType dataType)
+                         MPSDataType dataType,
+                         bool useMPSStridedAPI)
     : _tensor(src) {
   TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
   // extract the pointer to MTLBuffer from the Tensor's storage
   id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
-  // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
-  if (needsGather(src) && gatherTensorData) {
-    Tensor emptyShell = Tensor();
-    // use "_tensor" from Placeholder to retain view's output during its usage in other ops
-    _tensor = gatherViewTensor(src, emptyShell);
-    if (!_tensor.has_storage()) {
-      // if we cannot gather, we make the tensor contiguous implicitly, and keep
-      // it in placeholder to be able to retrieve it when we return from constructor
-      _tensor = src.clone(MemoryFormat::Contiguous);
+
+  const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
+  // Use the gather kernel to resolve strides for macOS < 15.0
+  // Starting with macOS 15.0, MPS supports native strides directly in the kernels
+  if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
+    if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
+      Tensor emptyShell = Tensor();
+      // use "_tensor" from Placeholder to retain view's output during its usage in other ops
+      _tensor = gatherViewTensor(src, emptyShell);
+      if (!_tensor.has_storage()) {
+        // if we cannot gather, we make the tensor contiguous implicitly, and keep
+        // it in placeholder to be able to retrieve it when we return from constructor
+        _tensor = src.clone(MemoryFormat::Contiguous);
+      }
+      srcBuf = getMTLBufferStorage(_tensor);
     }
-    srcBuf = getMTLBufferStorage(_tensor);
   }
 
   // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
@@ -389,9 +526,66 @@ void printTensorNDArray(const Tensor& t) {
     const auto scalar_type = _tensor.scalar_type();
     dataType = _tensor.dim() == 0 ? getMPSScalarType(scalar_type) : getMPSDataType(scalar_type);
   }
-  _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
-                                                     shape:mpsShape ? mpsShape : getMPSShape(_tensor)
-                                                  dataType:dataType] autorelease];
+
+  // Tensor is contiguous and has no storage offset.
+  // Wrap it directly inside MPSGraphTensorData
+  if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
+    _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
+                                                      shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
+                                                   dataType:dataType] autorelease];
+  } else {
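+    // Strided or offset tensor: wrap the base allocation as a flat MPSNDArray at this
+    // tensor's storage offset, then reinterpret it with the tensor's sizes and strides.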
+    IntArrayRef view_shape;
+    if (mpsShape_) {
+      _tensor = getTensorView(src, mpsShape_);
+    }
+
+    MPSShape* mpsShape = getMPSShape(_tensor);
+    MPSShape* mpsStrides = getMPSShape(_tensor.strides());
+
+    IntArrayRef baseShape;
+    if (src.is_view()) {
+      baseShape = src._base().sizes();
+    } else {
+      baseShape = getIMPSAllocator()->getBufferShape(src.storage().data());
+    }
+    int flattenedShaped = 1;
+    for (const auto i : c10::irange(baseShape.size())) {
+      flattenedShaped *= baseShape[i];
+    }
+    MPSShape* mpsBaseShape = @[ @(flattenedShaped) ];
+    MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsBaseShape];
+    srcTensorDesc.preferPackedRows = YES;
+    MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
+                                                          offset:src.storage_offset() * src.element_size()
+                                                      descriptor:srcTensorDesc] autorelease];
+    TORCH_INTERNAL_ASSERT(srcNDArray);
+    if (src.dim() != 0) {
+      srcNDArray = getStridedMPSNDArray(_tensor, srcNDArray);
+    } else {
+      bool needsReshape = false;
+      NSMutableArray* mpsExpandedShape = nil;
+      NSMutableArray* mpsExpandedStrides = nil;
+
+      if (src.dim() > 0 && src.stride(-1) != 1) {
+        needsReshape = true;
+        mpsExpandedShape = [NSMutableArray arrayWithArray:mpsShape];
+        mpsExpandedStrides = [NSMutableArray arrayWithArray:mpsStrides];
+        [mpsExpandedShape addObject:@1];
+        [mpsExpandedStrides addObject:@1];
+      }
+      srcNDArray = [srcNDArray arrayViewWithShape:needsReshape ? mpsExpandedShape : getMPSShape(src)
+                                          strides:needsReshape ? mpsExpandedStrides : getMPSShape(src.strides())];
+      TORCH_INTERNAL_ASSERT(srcNDArray);
+
+      if (needsReshape) {
+        MPSNDArrayIdentity* identity =
+            [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
+        srcNDArray = [identity reshapeWithCommandBuffer:nil sourceArray:srcNDArray shape:mpsShape destinationArray:nil];
+      }
+      TORCH_INTERNAL_ASSERT(srcNDArray);
+    }
+    _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcNDArray] autorelease];
+  }
 
   TORCH_INTERNAL_ASSERT(_value);
   _placeholder = mpsGraphTensor;
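As a rough usage sketch (not part of this commit), the new `useMPSStridedAPI` parameter lets a non-contiguous MPS tensor feed a graph without an explicit gather on macOS 15 and newer. The call site below is hypothetical: `graphTensor` is assumed to be an existing MPSGraphTensor, and passing `MPSDataTypeInvalid` is assumed to take the type-derivation branch shown in the constructor above.

  // Hypothetical call site; `graphTensor` is assumed to exist elsewhere.
  Tensor src = at::rand({4, 8}, at::TensorOptions().device(at::kMPS)).transpose(0, 1);
  Placeholder srcPlaceholder(graphTensor,
                             src,
                             /*mpsShape_=*/nil,
                             /*gatherTensorData=*/true,
                             /*dataType=*/MPSDataTypeInvalid,
                             /*useMPSStridedAPI=*/true);
  // On macOS 15+ this wraps the transposed view as a strided MPSNDArray;
  // on older systems it falls back to the gather/clone path above.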