Skip to content

Commit 770e10a

Browse files
mooskagh authored and tensorflower-gardener committed
[XLA:GPU] Move Dot strength reduction out of algebraic simplifier
and run it only once. The plan for the follow up changes is to remove vec×matrix reduction (currently regresses some models for unrelated reasons), and only keep vec×vec. PiperOrigin-RevId: 784472699
1 parent 835af34 commit 770e10a

14 files changed

+636
-225
lines changed

third_party/xla/xla/service/gpu/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ cc_library(
10231023
"//xla:xla_data_proto_cc",
10241024
"//xla/hlo/ir:hlo",
10251025
"//xla/service:algorithm_util",
1026+
"//xla/service/gpu/transforms:dot_algorithm_rewriter",
10261027
"//xla/stream_executor:blas",
10271028
"//xla/stream_executor:device_description",
10281029
"//xla/stream_executor:device_memory",
@@ -1037,7 +1038,6 @@ cc_library(
10371038
"@com_google_absl//absl/status:statusor",
10381039
"@com_google_absl//absl/strings",
10391040
"@com_google_absl//absl/types:span",
1040-
"@llvm-project//llvm:Support",
10411041
"@local_tsl//tsl/platform:errors",
10421042
"@local_tsl//tsl/platform:statusor",
10431043
],
@@ -1609,6 +1609,7 @@ cc_library(
16091609
"//xla/service/gpu/transforms:dot_dimension_sorter",
16101610
"//xla/service/gpu/transforms:dot_normalizer",
16111611
"//xla/service/gpu/transforms:dot_operand_converter",
1612+
"//xla/service/gpu/transforms:dot_strength_reduction",
16121613
"//xla/service/gpu/transforms:double_buffer_loop_unrolling",
16131614
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
16141615
"//xla/service/gpu/transforms:explicit_collectives_group_async_wrapper",

third_party/xla/xla/service/gpu/gpu_compiler.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ limitations under the License.
213213
#include "xla/service/gpu/transforms/dot_dimension_sorter.h"
214214
#include "xla/service/gpu/transforms/dot_normalizer.h"
215215
#include "xla/service/gpu/transforms/dot_operand_converter.h"
216+
#include "xla/service/gpu/transforms/dot_strength_reduction.h"
216217
#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h"
217218
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
218219
#include "xla/service/gpu/transforms/explicit_collectives_group_async_wrapper.h"
@@ -849,6 +850,8 @@ absl::Status RunOptimizationPasses(
849850
pipeline.AddPass<ScatterExpander>(
850851
ScatterExpander::kEliminateSimpleScatters);
851852
pipeline.AddPass<ScatterSliceSimplifier>();
853+
pipeline.AddPass<DotStrengthReduction>(
854+
gpu_target_config.device_description.gpu_compute_capability());
852855
pipeline.AddPass<GpuAlgebraicSimplifier>(layout_insensitive_algsimp_opts,
853856
gpu_version);
854857
pipeline.AddPass<BitcastDtypesExpander>();
@@ -1348,7 +1351,7 @@ AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions(
13481351
bool is_rocm) {
13491352
AlgebraicSimplifierOptions opts;
13501353

1351-
opts.set_enable_dot_strength_reduction(true);
1354+
opts.set_enable_dot_strength_reduction(false);
13521355
// On GPU it helps to reorder them so that the fused cuDNN kernel can be
13531356
// used.
13541357
opts.set_enable_conv_add_multiply_reorder(true);

third_party/xla/xla/service/gpu/ir_emission_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ absl::StatusOr<bool> IsCublasSupportedMatMul(
8181
int num_matrix_operands = 0;
8282
for (int operand : {0, 1}) {
8383
TF_ASSIGN_OR_RETURN(DotOperandDims dims,
84-
DotOperandDims::FromDot(&dot, operand));
84+
DotOperandDims::FromDotOperand(&dot, operand));
8585
// cuBLAS only supports single contracting dimension.
8686
if (dims.DimensionCount(DotOperandDims::kContracting) != 1) {
8787
return false;

third_party/xla/xla/service/gpu/matmul_indexing_utils.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "xla/service/gpu/matmul_indexing_utils.h"
1717

18+
#include <array>
1819
#include <cstdint>
1920
#include <iterator>
2021
#include <vector>
@@ -101,7 +102,14 @@ DotOperandDims::DotOperandDims(Shape shape,
101102
contracting_dims.end());
102103
}
103104

104-
absl::StatusOr<DotOperandDims> DotOperandDims::FromDot(
105+
absl::StatusOr<std::array<DotOperandDims, 2>> DotOperandDims::FromDot(
106+
const HloInstruction* dot) {
107+
TF_ASSIGN_OR_RETURN(auto lhs_dims, FromDotOperand(dot, 0));
108+
TF_ASSIGN_OR_RETURN(auto rhs_dims, FromDotOperand(dot, 1));
109+
return std::array<DotOperandDims, 2>{lhs_dims, rhs_dims};
110+
}
111+
112+
absl::StatusOr<DotOperandDims> DotOperandDims::FromDotOperand(
105113
const HloInstruction* dot, int operand_idx) {
106114
TF_RET_CHECK(operand_idx == 0 || operand_idx == 1);
107115
const Shape& shape = dot->operand(operand_idx)->shape();

third_party/xla/xla/service/gpu/matmul_indexing_utils.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,12 @@ class DotOperandDims {
6363
absl::Span<const int64_t> contracting_dims);
6464

6565
enum Category { kBatch, kNonContracting, kContracting };
66+
// Creates a DotOperandDims from a dot instruction.
67+
static absl::StatusOr<std::array<DotOperandDims, 2>> FromDot(
68+
const HloInstruction* dot);
6669
// Creates a DotOperandDims from a dot instruction and operand index (0 or 1).
67-
static absl::StatusOr<DotOperandDims> FromDot(const HloInstruction* dot,
68-
int operand_idx);
70+
static absl::StatusOr<DotOperandDims> FromDotOperand(
71+
const HloInstruction* dot, int operand_idx);
6972
// Converts two DotOperandDims to a DotDimensionNumbers.
7073
static absl::StatusOr<DotDimensionNumbers> IntoDotDimensionNumbers(
7174
const DotOperandDims& lhs_dims, const DotOperandDims& rhs_dims);

third_party/xla/xla/service/gpu/matmul_utils.cc

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ limitations under the License.
4040
#include "xla/service/algorithm_util.h"
4141
#include "xla/service/gpu/backend_configs.pb.h"
4242
#include "xla/service/gpu/matmul_indexing_utils.h"
43+
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
4344
#include "xla/shape.h"
4445
#include "xla/shape_util.h"
4546
#include "xla/status_macros.h"
@@ -895,17 +896,40 @@ PrimitiveType GetGemmAccumulatorType(HloDotInstruction* dot) {
895896
if (accumulator_type.ok()) {
896897
return accumulator_type.value();
897898
}
898-
// Otherwise, return the default accumulator type for the output type.
899-
PrimitiveType output_type = dot->shape().element_type();
900-
switch (output_type) {
901-
case PrimitiveType::F16:
902-
case PrimitiveType::BF16:
903-
return PrimitiveType::F32;
904-
case PrimitiveType::F32:
905-
case PrimitiveType::F64:
906-
case PrimitiveType::S32:
899+
900+
PrimitiveType shape_type = dot->shape().element_type();
901+
// If the output type is a floating point type with less than or equal to 32
902+
// bits, use f32 as the accumulator type.
903+
if (primitive_util::IsFloatingPointType(shape_type) &&
904+
primitive_util::BitWidth(shape_type) <= primitive_util::BitWidth(F32)) {
905+
return F32;
906+
}
907+
return shape_type;
908+
}
909+
910+
absl::StatusOr<HloInstruction*> MakeMultiplyForDotPrecisionAlgorithm(
911+
HloInstruction* lhs, HloInstruction* rhs,
912+
const PrecisionConfig::Algorithm& algorithm) {
913+
switch (algorithm) {
914+
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
915+
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32(lhs, rhs);
916+
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
917+
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X3(lhs, rhs);
918+
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
919+
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X6(lhs, rhs);
920+
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
921+
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X9(lhs, rhs);
922+
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
923+
return DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32(lhs, rhs);
924+
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
925+
return DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32X3(lhs, rhs);
926+
case PrecisionConfig::ALG_DOT_F32_F32_F32:
927+
case PrecisionConfig::ALG_UNSET:
928+
return lhs->parent()->AddInstruction(HloInstruction::CreateBinary(
929+
lhs->shape(), HloOpcode::kMultiply, lhs, rhs));
907930
default:
908-
return output_type;
931+
return absl::InvalidArgumentError(
932+
absl::StrCat("Unsupported dot precision algorithm: ", algorithm));
909933
}
910934
}
911935

third_party/xla/xla/service/gpu/matmul_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ bool IsDotSupportedByClassicalEmitters(const HloInstruction& dot);
6363
// from the dot algorithm or inferred from the output type).
6464
PrimitiveType GetGemmAccumulatorType(HloDotInstruction* dot);
6565

66+
// Makes algorithm specific set of instructions which would multiply lhs and rhs
67+
// like the dot with the given precision algorithm would. Useful e.g. rewriting
68+
// dot as multiply+reduce.
69+
absl::StatusOr<HloInstruction*> MakeMultiplyForDotPrecisionAlgorithm(
70+
HloInstruction* lhs, HloInstruction* rhs,
71+
const PrecisionConfig::Algorithm& algorithm);
72+
6673
// extending plain MatrixLayout struct with creator functions
6774
struct MatrixLayout : public se::gpu::MatrixLayout {
6875
// Returns the matrix layout for a logical shape (batch, rows, columns).

third_party/xla/xla/service/gpu/transforms/BUILD

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ xla_cc_test(
8787
"//xla/hlo/transforms/simplifiers:algebraic_simplifier",
8888
"//xla/service:pattern_matcher",
8989
"//xla/stream_executor:device_description",
90+
"//xla/stream_executor/cuda:cuda_compute_capability",
9091
"//xla/tests:xla_internal_test_main",
9192
"//xla/tsl/platform:statusor",
9293
"@com_google_absl//absl/strings:string_view",
9394
"@com_google_googletest//:gtest",
94-
"@local_tsl//tsl/platform:statusor",
9595
],
9696
)
9797

@@ -1182,6 +1182,48 @@ xla_test(
11821182
),
11831183
)
11841184

1185+
cc_library(
1186+
name = "dot_strength_reduction",
1187+
srcs = ["dot_strength_reduction.cc"],
1188+
hdrs = ["dot_strength_reduction.h"],
1189+
deps = [
1190+
":dot_algorithm_rewriter",
1191+
"//xla:literal",
1192+
"//xla:literal_util",
1193+
"//xla:shape_util",
1194+
"//xla/backends/gpu/codegen/triton:support",
1195+
"//xla/hlo/ir:hlo",
1196+
"//xla/hlo/transforms/expanders:op_expander_pass",
1197+
"//xla/service/gpu:matmul_indexing_utils",
1198+
"//xla/service/gpu:matmul_utils",
1199+
"//xla/stream_executor:device_description",
1200+
"//xla/tsl/platform:statusor",
1201+
"@com_google_absl//absl/algorithm:container",
1202+
"@com_google_absl//absl/log:check",
1203+
"@com_google_absl//absl/status:statusor",
1204+
"@com_google_absl//absl/strings:string_view",
1205+
"@com_google_absl//absl/types:span",
1206+
],
1207+
)
1208+
1209+
xla_test(
1210+
name = "dot_strength_reduction_test",
1211+
srcs = ["dot_strength_reduction_test.cc"],
1212+
backends = ["gpu"],
1213+
deps = [
1214+
":dot_strength_reduction",
1215+
"//xla/hlo/ir:hlo",
1216+
"//xla/hlo/testlib:filecheck",
1217+
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
1218+
"//xla/hlo/testlib:verified_hlo_module",
1219+
"//xla/stream_executor:device_description",
1220+
"//xla/stream_executor/cuda:cuda_compute_capability",
1221+
"//xla/tsl/platform:statusor",
1222+
"@com_google_absl//absl/log:check",
1223+
"@com_google_googletest//:gtest_main",
1224+
],
1225+
)
1226+
11851227
cc_library(
11861228
name = "double_buffer_loop_unrolling",
11871229
srcs = ["double_buffer_loop_unrolling.cc"],

third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -100,65 +100,8 @@ bool GpuAlgebraicSimplifierVisitor::SupportedDotPrecisionConfig(
100100
absl::StatusOr<HloInstruction*>
101101
GpuAlgebraicSimplifierVisitor::MakeMultiplyForPrecisionAlgorithm(
102102
HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) {
103-
const auto algorithm = dot->precision_config().algorithm();
104-
switch (algorithm) {
105-
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
106-
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32(lhs, rhs);
107-
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
108-
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X3(lhs, rhs);
109-
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
110-
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X6(lhs, rhs);
111-
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9:
112-
return DotAlgorithmRewriter::MakeMultiplyForBF16BF16F32X9(lhs, rhs);
113-
case PrecisionConfig::ALG_DOT_TF32_TF32_F32:
114-
return DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32(lhs, rhs);
115-
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
116-
return DotAlgorithmRewriter::MakeMultiplyForTF32TF32F32X3(lhs, rhs);
117-
case PrecisionConfig::ALG_DOT_F32_F32_F32:
118-
return MakeBinaryHlo(HloOpcode::kMultiply, lhs, rhs);
119-
case PrecisionConfig::ALG_UNSET:
120-
return MakeBinaryHlo(HloOpcode::kMultiply, lhs, rhs);
121-
default:
122-
CHECK(false) << "Unsupported dot precision algorithm: " << algorithm;
123-
}
124-
}
125-
126-
bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce(
127-
const HloInstruction* hlo) {
128-
if (!options_.enable_dot_strength_reduction()) {
129-
return false;
130-
}
131-
132-
const HloDotInstruction* dot = DynCast<HloDotInstruction>(hlo);
133-
if (dot == nullptr) {
134-
return false;
135-
}
136-
137-
const HloInstruction* lhs = dot->operand(0);
138-
const HloInstruction* rhs = dot->operand(1);
139-
DotDimensionNumbers dnums = dot->dot_dimension_numbers();
140-
bool lhs_is_vector = (dnums.lhs_batch_dimensions_size() +
141-
dnums.lhs_contracting_dimensions_size() ==
142-
lhs->shape().dimensions().size());
143-
bool rhs_is_vector = (dnums.rhs_batch_dimensions_size() +
144-
dnums.rhs_contracting_dimensions_size() ==
145-
rhs->shape().dimensions().size());
146-
// Strength-reduce vector-vector dots since they are not supported by
147-
// GemmFusion.
148-
if (lhs_is_vector && rhs_is_vector) {
149-
return true;
150-
}
151-
152-
absl::StatusOr<bool> is_too_small =
153-
IsMatrixMultiplicationTooSmallForRewriting(*hlo, /*threshold=*/10000000);
154-
CHECK_OK(is_too_small.status());
155-
if (is_too_small.value()) {
156-
return true;
157-
}
158-
159-
// If GemmFusion cannot handle this dot, we should strength-reduce it so that
160-
// it can be handled by the fusion pipeline.
161-
return !legacy_triton::CanTritonHandleGEMM(*dot, compute_capability_);
103+
return MakeMultiplyForDotPrecisionAlgorithm(
104+
lhs, rhs, dot->precision_config().algorithm());
162105
}
163106

164107
} // namespace xla::gpu

third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor {
4141

4242
absl::Status HandleAdd(HloInstruction* add) override;
4343

44-
bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override;
45-
4644
private:
4745
// Returns true if the dot precision config is supported by simplifier.
4846
bool SupportedDotPrecisionConfig(const PrecisionConfig& config,

0 commit comments

Comments (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy