mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add debuging code and format
This commit is contained in:
@@ -45,7 +45,7 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>; // ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
|
||||
@@ -83,13 +83,43 @@ void topid_unique_gen(
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
for(int i = 0; i < m; i++)
|
||||
{
|
||||
std::cout << "Line " << i << "\t";
|
||||
for(int j = 0; j < n; j++)
|
||||
{
|
||||
std::cout << ck_tile::type_convert<float>(data(i, j)) << "\t";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
template <typename IndexType>
|
||||
void output_matrix_3d(ck_tile::HostTensor<IndexType>& data, int M, int N, int J)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
std::cout << "experts: " << m << " Line: " << n << "\t";
|
||||
for(int j = 0; j < J; j++)
|
||||
{
|
||||
std::cout << ck_tile::type_convert<float>(data(m, n, j)) << "\t";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("k", "2", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
@@ -112,7 +142,7 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"1",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
@@ -176,9 +206,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
auto api_str = [&]() {
|
||||
return std::string("moeg");
|
||||
}();
|
||||
auto api_str = [&]() { return std::string("moeg"); }();
|
||||
|
||||
auto stride_str = [&]() {
|
||||
if(stride == hidden_size)
|
||||
@@ -245,7 +273,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f, seed, true}(
|
||||
topk_weight_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
@@ -333,116 +361,122 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
// output_matrix_2d(a_host, tokens, hidden_size);
|
||||
std::cout << sorted_token_ids_host << std::endl;
|
||||
std::cout << num_sorted_tiles_host << std::endl;
|
||||
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
|
||||
std::cout << sorted_expert_ids_host << std::endl;
|
||||
// std::cout << topk_weight_host << std::endl;
|
||||
|
||||
// std::cout << sorted_weight_host << std::endl;
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
//
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride,
|
||||
max_num_tokens_padded};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
block_m);
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
//
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride,
|
||||
max_num_tokens_padded};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
return pass;
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
//constexpr index_t block_m = BlockShape::Block_M0;
|
||||
// constexpr index_t block_m = BlockShape::Block_M0;
|
||||
int max_num_tokens_padded = hargs.max_num_tokens_padded;
|
||||
//hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
|
||||
// hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
|
||||
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
|
||||
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ struct FusedMoeGemmHostArgs
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
|
||||
};
|
||||
|
||||
|
||||
@@ -124,9 +124,9 @@ struct FusedMoeGemmPipeline_General
|
||||
index_t hidden_size,
|
||||
index_t intermediate_size)
|
||||
{
|
||||
ignore = d_window_;
|
||||
ignore = hidden_size;
|
||||
ignore = intermediate_size;
|
||||
ignore = d_window_;
|
||||
ignore = hidden_size;
|
||||
ignore = intermediate_size;
|
||||
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
|
||||
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
|
||||
@@ -191,12 +191,13 @@ struct FusedMoeGemmPipeline_General
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc, a_lds_win, g_dram_block);
|
||||
}
|
||||
// relu
|
||||
const auto activation = ck_tile::element_wise::Gelu{};
|
||||
tile_elementwise_inout(activation, s_acc, s_acc);
|
||||
#if 0
|
||||
#if 1
|
||||
PrintMem(s_acc);
|
||||
#endif
|
||||
// relu
|
||||
const auto activation = ck_tile::element_wise::Gelu{};
|
||||
tile_elementwise_inout(activation, s_acc, s_acc);
|
||||
|
||||
// move sacc to LDS
|
||||
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
|
||||
@@ -238,7 +239,7 @@ struct FusedMoeGemmPipeline_General
|
||||
index_t iCounter1 = n1_loops - 1;
|
||||
while(iCounter1 > 0)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc, y, d);
|
||||
block_sync_lds();
|
||||
@@ -253,7 +254,7 @@ struct FusedMoeGemmPipeline_General
|
||||
}
|
||||
// tail
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc, y, d);
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
{
|
||||
using WG = decltype(GetWarpGemm0<Problem>());
|
||||
using S_ = typename Problem::BlockShape;
|
||||
static_assert(S_::WarpPerBlock_N0==4);
|
||||
static_assert(S_::WarpPerBlock_N0 == 4);
|
||||
constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<S_::WarpPerBlock_M0>,
|
||||
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
|
||||
@@ -240,13 +240,14 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
constexpr auto y_outer_dstr_enc =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
|
||||
sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
|
||||
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
|
||||
@@ -260,13 +261,14 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
constexpr auto d_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
constexpr auto d_outer_dstr_enc =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>,
|
||||
sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
|
||||
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
|
||||
@@ -356,8 +358,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
1>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
@@ -396,8 +398,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
1>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
|
||||
@@ -52,16 +52,16 @@ struct BlockGemmARegBRegCRegV2
|
||||
// M->N Warp
|
||||
// constexpr auto a_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<NWarp>,
|
||||
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// tuple<sequence<MIterPerWarp, MWarp>,
|
||||
// sequence<KIterPerWarp>>, tuple<sequence<1, 0>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
// constexpr auto b_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<MWarp>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>,
|
||||
// sequence<KIterPerWarp>>, tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
Reference in New Issue
Block a user