Fix AccDataType and CDataType

1. Fix AccDataType and CDataType
2. Remove indent
3. Align merge_transform for tutorial
This commit is contained in:
mhYang
2025-03-25 20:08:53 +00:00
parent 67072b3ba9
commit 9ecde871a3
4 changed files with 15 additions and 19 deletions

View File

@@ -106,9 +106,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -205,9 +205,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));

View File

@@ -28,12 +28,7 @@ int main(int argc, char* argv[])
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
// TODO: FIXME
#ifdef INSTRUCTION_SCHEDULE
using CDataType = float;
#else
using CDataType = ck_tile::half_t;
#endif
ck_tile::index_t verification = 0;
ck_tile::index_t M = 3328;

View File

@@ -90,27 +90,27 @@ struct Gemm
return [=](index_t block_1d_id) {
constexpr index_t M01 = 2;
constexpr index_t GroupNum = 4;
const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
const auto group_id_x = block_1d_id % GroupNum;
const auto remap_block_1d_id = (group_id_x <= big_group_num)
? (group_id_x * group_size + block_1d_id / GroupNum)
: (group_id_x * group_size + big_group_num - group_id_x);
const index_t idx_M0 = remap_block_1d_id / N0;
const index_t idx_N0 = remap_block_1d_id % N0;
const index_t M0_mod_M01 = M0 % M01;
const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
const index_t idx_M00 = idx_M0 / M01;
const index_t idx_M01 = idx_M0 % M01;
const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01, idx_N0_M01_local / M01_adapt);
};
#else
@@ -120,7 +120,7 @@ struct Gemm
multi_index<2> unmerged;
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
};
#endif

View File

@@ -21,6 +21,7 @@ struct GridGemm
using ADataType = typename Problem::ADataType;
using BDataType = typename Problem::BDataType;
using CDataType = typename Problem::CDataType;
using AccDataType = typename Problem::AccDataType;
using CElementFunction = typename Problem::CElementFunction;
static constexpr auto kMPerBlock = Policy::kMPerBlock;
@@ -110,7 +111,7 @@ struct GridGemm
using UniversalGemmProblem = UniversalGemmPipelineProblem<ADataType,
BDataType,
CDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmPipelineScheduler::Intrawave,