From 9ecde871a3be2dfd61eaa62c8345583d2c56037b Mon Sep 17 00:00:00 2001 From: mhYang Date: Tue, 25 Mar 2025 20:08:53 +0000 Subject: [PATCH] Fix AccDataType and CDataType 1. Fix AccDataType and CDataType 2. Remove indent 3. Align merge_transform for tutorial --- ...ipeline_agmem_bgmem_creg_default_policy.hpp | 8 ++++---- .../ck_tile/99_toy_example/02_gemm/gemm.cpp | 5 ----- .../ck_tile/99_toy_example/02_gemm/gemm.hpp | 18 +++++++++--------- .../99_toy_example/02_gemm/grid_gemm.hpp | 3 ++- 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 940c659031..6fdb65e26b 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -106,9 +106,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( + make_tuple(make_merge_transform( make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( + make_merge_transform( make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -205,9 +205,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( + make_tuple(make_merge_transform( make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( + make_merge_transform( make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp index 0c91af3ce5..7aea4b376d 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -28,12 +28,7 @@ int main(int argc, char* argv[]) using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; -// TODO: FIXME -#ifdef INSTRUCTION_SCHEDULE - using CDataType = float; -#else using CDataType = ck_tile::half_t; -#endif ck_tile::index_t verification = 0; ck_tile::index_t M = 3328; diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index 78341d97a0..b2ed93c6cf 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -90,27 +90,27 @@ struct Gemm return [=](index_t block_1d_id) { constexpr index_t M01 = 2; constexpr index_t GroupNum = 4; - + const auto group_size = integer_divide_ceil(M0 * N0, GroupNum); const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0); - + const auto group_id_x = block_1d_id % GroupNum; - + const auto remap_block_1d_id = (group_id_x <= big_group_num) ? (group_id_x * group_size + block_1d_id / GroupNum) : (group_id_x * group_size + big_group_num - group_id_x); - + const index_t idx_M0 = remap_block_1d_id / N0; const index_t idx_N0 = remap_block_1d_id % N0; - + const index_t M0_mod_M01 = M0 % M01; - + const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01; - + const index_t idx_M00 = idx_M0 / M01; const index_t idx_M01 = idx_M0 % M01; const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - + return make_multi_index(idx_N0_M01_local % M01_adapt + idx_M00 * M01, idx_N0_M01_local / M01_adapt); }; #else @@ -120,7 +120,7 @@ struct Gemm multi_index<2> unmerged; unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); - return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); }; #endif diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp index 0e3e1fef95..4e15ff2845 100644 --- a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp @@ -21,6 +21,7 @@ struct GridGemm using ADataType = typename Problem::ADataType; using BDataType = typename Problem::BDataType; using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; using CElementFunction = typename Problem::CElementFunction; static constexpr auto kMPerBlock = Policy::kMPerBlock; @@ -110,7 +111,7 @@ struct GridGemm using UniversalGemmProblem = UniversalGemmPipelineProblem