diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 0e0ee9dbd8..667bb80ce9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -217,17 +217,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 ////////////// global window & register ///////////////// // A DRAM tile window for load auto a_copy_dram_window = - make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // B DRAM tile window for load auto b_copy_dram_window = - make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); // A register tile for global load constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); @@ -317,25 +317,25 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 BLdsTile b_block_tile1; auto a_lds_ld_window0 = - make_tile_window_linear(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto a_lds_ld_window1 = - make_tile_window_linear(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto b_lds_ld_window0 = - make_tile_window_linear(b_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); auto b_lds_ld_window1 = - make_tile_window_linear(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index e528847438..f6920f1c57 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -17,56 +17,6 @@ namespace ck_tile { struct GemmPipelineAgBgCrCompV4DefaultPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{} / KPack, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackB(); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kNPerBlock)*KPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b555cf75e0..6890cf2f64 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -19,6 +19,245 @@ struct UniversalGemmBasePolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + /** + * @brief Create LDS block descriptor for B tensor. + * + * @tparam Problem Gemm pipeline problem. + * @return B tensor LDS block descriptor. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + // using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + +#if 1 + // if constexpr(std::is_same_v) + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = TileDistributionEncodingPattern2D; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif + } + /** * @brief Get the maximum global memory vector load size. * @@ -301,7 +540,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); constexpr index_t smem_size_a = integer_least_multiple( sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); return smem_size_a; @@ -310,7 +549,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); return smem_size_b; @@ -330,245 +569,6 @@ struct UniversalGemmBasePolicy struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using ADataType = remove_cvref_t; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - /** - * @brief Create LDS block descriptor for B tensor. - * - * @tparam Problem Gemm pipeline problem. - * @return B tensor LDS block descriptor. - */ - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - // using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - -#if 1 - // if constexpr(std::is_same_v) - { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - BK0 * number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } -#else - else // B is Row Major - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = TileDistributionEncodingPattern2D; - - constexpr auto BK0 = number{}; - constexpr auto BK1 = number{}; - // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N0 = TileEncodingPattern::X0; - constexpr auto N1 = NPerBlock / N0; - - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; - - // constexpr auto KThreadWrite = - // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; - constexpr auto K0PerThreadRead = BK0 / KThreadRead; - - constexpr auto kfold = - (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - // b_lds_block_desc_unmerged, - // make_tuple(make_merge_transform_v3_division_mod( - // make_tuple(number{}, - // number{}, - // number{}, - // number{})), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}, number{})), - // make_pass_through_transform(BK1)), - // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), - // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - // return b_lds_block_desc_bk0_n_bk1; - return b_lds_block_desc_kn; - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( - // make_tuple(BK0, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); - - // constexpr auto b_lds_block_desc = transform_tensor_descriptor( - // b_lds_block_desc_bk0_n_bk1, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod(make_tuple(BK0, - // number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - - // return b_lds_block_desc; - } -#endif - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() {