Skip to content

Commit

Permalink
Solve cudainjectiveschedule problem (PaddlePaddle#824)
Browse files Browse the repository at this point in the history
* solve cudainjectiveschedule problem

add log

* delete log

* revert before 807

* add log

* fix bugs

* delete log and fix bug
  • Loading branch information
haozech authored Jun 27, 2022
1 parent 9286980 commit 0fff689
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cinn/frontend/pass/pass_test_helper.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void CompareResult(Program* program,

ASSERT_EQ(origin_out.size(), fused_out.size());
for (size_t i = 0; i < origin_out.size(); ++i) {
ASSERT_FLOAT_EQ(origin_out[i], fused_out[i]);
ASSERT_FLOAT_EQ(origin_out[i], fused_out[i]) << " i is " << i;
}
}

Expand Down
Empty file modified cinn/hlir/framework/op_lowering.cc
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion cinn/hlir/pass/test_dot_merger.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ TEST(DotMerger, lhs) {
// because op def changes with the macro
return;
}
int m = 2, k = 10201, n1 = 50, n2 = 50, n3 = 50, axis = 1;
int m = 2, k = 10201, n1 = 100, n2 = 100, n3 = 100, axis = 1;
NetBuilder builder("net_builder");
auto a = builder.CreateInput(Float(32), {m, k}, "A");
auto b = builder.CreateInput(Float(32), {k, n1}, "B");
Expand Down
55 changes: 35 additions & 20 deletions cinn/hlir/pe/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <absl/container/flat_hash_map.h>
#include <isl/cpp.h>
#include <math.h>

#include <algorithm>
#include <fstream>
Expand Down Expand Up @@ -2101,6 +2102,18 @@ int gcd(int a, int b) {
return a;
}

int MaxFactorLessThan(int a, int b) {
CHECK_GT(a, b);
int res = 1;
for (int i = 2; i <= (int)sqrt((double)a); i++) {
if (a % i == 0) {
if (i <= b) res = std::max(res, i);
if (a / i <= b) res = std::max(res, a / i);
}
}
return res;
}

void CudaScheduleInjectiveWithVectorize(poly::Stage *stage,
const std::vector<int> &output_shape,
const common::Target &target) {
Expand Down Expand Up @@ -2174,30 +2187,32 @@ void CudaScheduleInjective(poly::Stage *stage, const std::vector<int> &output_sh
CudaScheduleInjectiveWithVectorize(stage, output_shape, target);
return;
}
int dims = stage->n_out_dims() - 1;
int num_thread = target.max_num_threads();
if (stage->GetDimRange(dims) > num_thread) {
stage->Split(dims, gcd(stage->GetDimRange(dims), num_thread));
++dims;
int dims = stage->n_out_dims();
for (int i = 1; i < dims; i++) {
stage->Fuse(0, 1);
}

while (dims > 0 && stage->GetDimRange(dims - 1) * stage->GetDimRange(dims) < num_thread) {
stage->Fuse(dims - 1, dims);
--dims;
int num_thread = target.max_num_threads();
int num_block = 65535;
int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
if (prod_size <= num_thread) {
stage->Bind(0, "threadIdx.x");
return;
}

stage->Bind(dims, "threadIdx.x");
--dims;

while (dims > 2) {
stage->Fuse(dims - 1, dims);
--dims;
int new_num_thread = gcd(prod_size, num_thread);
if (new_num_thread % 32 != 0) {
new_num_thread = MaxFactorLessThan(prod_size, num_thread);
}
std::string block_idx = "blockIdx.x";
for (int j = 0; dims >= 0; ++j) {
block_idx.back() = 'x' + j;
stage->Bind(dims, block_idx);
--dims;
if (new_num_thread == 1) LOG(FATAL) << "prod_size out of range: " << prod_size;

bool need_more_split = prod_size > new_num_thread * num_block ? true : false;
if (need_more_split) {
LOG(FATAL) << "prod_size out of range: " << prod_size << ", and new_num_thread is : " << new_num_thread;
} else {
CHECK_GT(prod_size, new_num_thread);
stage->Split(0, new_num_thread);
stage->Bind(0, "blockIdx.x");
stage->Bind(1, "threadIdx.x");
}
}

Expand Down

0 comments on commit 0fff689

Please sign in to comment.
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy