mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Fix AccDataType and CDataType
1. Fix AccDataType and CDataType; 2. Remove indent; 3. Align merge_transform for tutorial.
This commit is contained in:
@@ -106,9 +106,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
@@ -205,9 +205,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -28,12 +28,7 @@ int main(int argc, char* argv[])
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
// TODO: FIXME
|
||||
#ifdef INSTRUCTION_SCHEDULE
|
||||
using CDataType = float;
|
||||
#else
|
||||
using CDataType = ck_tile::half_t;
|
||||
#endif
|
||||
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t M = 3328;
|
||||
|
||||
@@ -90,27 +90,27 @@ struct Gemm
|
||||
// Returns a callable that converts a linear block id into a 2-element multi-index.
// M0 and N0 are captured by value from the enclosing scope (not visible here);
// presumably they are the tile counts along M and N -- TODO confirm against the caller.
return [=](index_t block_1d_id) {
    constexpr index_t M01      = 2;
    constexpr index_t GroupNum = 4;

    // Split the M0 * N0 ids into GroupNum groups of at most group_size ids each.
    const auto group_size    = integer_divide_ceil(M0 * N0, GroupNum);
    const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);

    // Consecutive ids are dealt round-robin: x selects the group, y the slot inside it.
    const auto group_id_x = block_1d_id % GroupNum;
    const auto group_id_y = block_1d_id / GroupNum;

    // Both arms must add group_id_y. Without it, every id that lands in the second
    // arm collapses onto one remapped id (e.g. M0 * N0 == 10: ids 3 and 7 both gave 8
    // and 9 was never produced).
    const auto remap_block_1d_id =
        (group_id_x <= big_group_num)
            ? (group_id_x * group_size + group_id_y)
            : (group_id_x * group_size + big_group_num - group_id_x + group_id_y);

    // Row-major split of the remapped id over an M0 x N0 grid.
    const index_t idx_M0 = remap_block_1d_id / N0;
    const index_t idx_N0 = remap_block_1d_id % N0;

    // Rows are processed in bands of M01; the last band may be shorter.
    const index_t M0_mod_M01 = M0 % M01;
    const auto M01_adapt     = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;

    const index_t idx_M00          = idx_M0 / M01;
    const index_t idx_M01          = idx_M0 % M01;
    const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    // Re-linearize inside the band so the first index varies fastest.
    return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01,
                            idx_N0_M01_local / M01_adapt);
};
|
||||
#else
|
||||
@@ -120,7 +120,7 @@ struct Gemm
|
||||
multi_index<2> unmerged;
|
||||
unmerge.calculate_lower_index(unmerged, make_multi_index(block_id));
|
||||
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{}));
|
||||
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -21,6 +21,7 @@ struct GridGemm
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
using CDataType = typename Problem::CDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using CElementFunction = typename Problem::CElementFunction;
|
||||
|
||||
static constexpr auto kMPerBlock = Policy::kMPerBlock;
|
||||
@@ -110,7 +111,7 @@ struct GridGemm
|
||||
|
||||
using UniversalGemmProblem = UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmPipelineScheduler::Intrawave,
|
||||
|
||||
Reference in New Issue
Block a user