Skip to content

Commit e35e316

Browse files
Revert "[MPS][BE] Delete unused lerp functors (#152443)"
This reverts commit 0a2d320. Reverted #152443 on behalf of https://github.com/wdvr due to failing MPS test: test/test_optim.py::TestOptimRenewedMPS::test_can_load_from_to_named_state_dict_is_named_optim0_False_is_named_optim1_False_Adafactor_mps_float32 ([comment](#152443 (comment)))
1 parent fecaa60 commit e35e316

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

aten/src/ATen/native/mps/kernels/BinaryKernel.metal

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ struct sub_functor {
1818
}
1919
};
2020

21+
struct lerp_functor {
22+
template <typename T>
23+
inline T operator()(const T a, const T b) {
24+
return static_cast<T>(b);
25+
}
26+
};
27+
2128
struct add_alpha_functor {
2229
template <typename T>
2330
inline T operator()(const T a, const T b, const T alpha) {
@@ -222,6 +229,13 @@ struct complex_lerp_alpha_functor {
222229
}
223230
};
224231

232+
struct complex_lerp_functor {
233+
template <typename T>
234+
inline T operator()(const T a, const T b) {
235+
return T(b.x, b.y);
236+
}
237+
};
238+
225239
REGISTER_BINARY_OP(copysign, long, float);
226240
REGISTER_BINARY_OP(copysign, int, float);
227241
REGISTER_BINARY_OP(copysign, float, float);
@@ -268,6 +282,14 @@ REGISTER_BINARY_OP(sub, short, short);
268282
REGISTER_BINARY_OP(sub, uchar, uchar);
269283
REGISTER_BINARY_OP(sub, char, char);
270284
REGISTER_BINARY_OP(sub, bool, bool);
285+
REGISTER_BINARY_OP(lerp, long, long);
286+
REGISTER_BINARY_OP(lerp, int, int);
287+
REGISTER_BINARY_OP(lerp, float, float);
288+
REGISTER_BINARY_OP(lerp, half, half);
289+
REGISTER_BINARY_OP(lerp, short, short);
290+
REGISTER_BINARY_OP(lerp, uchar, uchar);
291+
REGISTER_BINARY_OP(lerp, char, char);
292+
REGISTER_BINARY_OP(lerp, bool, bool);
271293
REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
272294
REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
273295
REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
@@ -308,6 +330,7 @@ REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
308330
REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
309331
REGISTER_BINARY_OP(add, bfloat, bfloat);
310332
REGISTER_BINARY_OP(sub, bfloat, bfloat);
333+
REGISTER_BINARY_OP(lerp, bfloat, bfloat);
311334
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
312335
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
313336
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
@@ -324,6 +347,8 @@ REGISTER_BINARY_OP(add, float2, float2);
324347
REGISTER_BINARY_OP(add, half2, half2);
325348
REGISTER_BINARY_OP(sub, float2, float2);
326349
REGISTER_BINARY_OP(sub, half2, half2);
350+
REGISTER_BINARY_OP(lerp, float2, float2);
351+
REGISTER_BINARY_OP(lerp, half2, half2);
327352
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
328353
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
329354
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy