diff --git a/CMakeLists.txt b/CMakeLists.txt index b28a6d9127..bce13a2514 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -503,10 +503,6 @@ include_directories(BEFORE ) SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") -if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) -endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 93fd306e98..273f383da0 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + # -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 3425da6712..e8157a7a67 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -117,6 +117,10 @@ int run_gemm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + // ck_tile::FillMonotonicSeq{}(a_m_k); + // ck_tile::FillMonotonicSeq{}(b_k_n); + // ck_tile::FillConstant{1.f}(a_m_k); + // ck_tile::FillConstant{1.f}(b_k_n); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 335911860a..404b59f1d5 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -211,15 +211,16 @@ struct FillNormalDistributionIntegerValue template struct FillMonotonicSeq { - T init_value_{0}; + T init_value_{-1024}; T step_{1}; - template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = init_value_]() mutable { + T step_start = init_value_; + std::generate(first, last, [&, n = init_value_]() mutable { auto tmp = n; n += step_; + if (n > step_start + 2047) {step_start += step_; n = step_start;} return tmp; }); } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 433f5c0dcc..6caf67a2a6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -42,9 +42,6 @@ struct BlockGemmASmemBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - // if(threadIdx.x == 0 && blockIdx.x==0) { - // printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock); - // } constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -56,12 +53,12 @@ struct BlockGemmASmemBSmemCRegV1 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + // constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + // constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() % NWarp; + // const index_t iMWarp = get_warp_id() / NWarp; + // const index_t iNWarp = get_warp_id() % NWarp; // if(threadIdx.x == 0 && blockIdx.x==0) { // printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter); @@ -69,91 +66,69 @@ struct BlockGemmASmemBSmemCRegV1 // MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8 - // construct A-warp-window auto a_warp_window_tmp = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - -#if 0 // FIXME: using array will cause register spill - array, MIterPerWarp> a_warp_windows{ - {a_warp_window_tmp}}; - - for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - // construct B-warp-window + make_tuple(MPerBlock, KPerBlock), + {0, 0}, + Policy::template MakeALDSTileDistribution()); auto b_warp_window_tmp = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, - make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + make_tuple(NPerBlock, KPerBlock), + {0, 0}, + Policy::template MakeBLDSTileDistribution()); -#if 0 // FIXME: using array will cause register spill - array, NIterPerWarp> b_warp_windows{ - {b_warp_window_tmp}}; + auto a_block_tensor = load_tile(a_warp_window_tmp); + auto b_block_tensor = load_tile(b_warp_window_tmp); - for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) - { - for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) - { - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - } - } -#else - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + // if (threadIdx.x == 0) { + // printf("0\n"); + // constexpr auto span_2d = decltype(a_block_tensor)::get_distributed_spans(); + // sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + // sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // printf("%f %f,", type_convert(a_block_tensor(i_j_idx)), type_convert(b_block_tensor(i_j_idx))); + // }); + // printf("\n"); + // }); + // } + // __syncthreads(); + using AWarpDstr = typename WG::AWarpDstr; + using BWarpDstr = typename WG::BWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); -#endif - - using CWarpDstr = typename WG::CWarpDstr; + using AWarpTensor = typename WG::AWarpTensor; + using BWarpTensor = typename WG::BWarpTensor; using CWarpTensor = typename WG::CWarpTensor; + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -173,6 +148,36 @@ struct BlockGemmASmemBSmemCRegV1 }); }); }); + // constexpr auto c_warp_y_lengths = + // to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + // // hot loop: + // static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // // read A warp tensor from A block window + + // static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // // read B warp tensor from B Block window + // // const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // // read C warp tensor from C block tensor + // CWarpTensor c_warp_tensor; + + // c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + // merge_sequences(sequence{}, c_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // // warp GEMM + // WG{}(c_warp_tensor, a_warp_tensor(mIter, kIter), b_warp_tensor(nIter, kIter)); + + // // write C warp tensor into C block tensor + // c_block_tensor.set_y_sliced_thread_data( + // merge_sequences(sequence{}, c_warp_y_index_zeros), + // merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + // c_warp_tensor.get_thread_buffer()); + // }); + // }); + // }); } CK_TILE_DEVICE static constexpr auto MakeCBlockTile() @@ -217,5 +222,72 @@ struct BlockGemmASmemBSmemCRegV1 return c_block_tensor; } }; + // construct A-warp-window + // auto a_warp_window_tmp = make_tile_window( + // a_block_window.get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + // make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); +// #if 0 // FIXME: using array will cause register spill +// array, MIterPerWarp> a_warp_windows{ +// {a_warp_window_tmp}}; +// for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) +// { +// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) +// { +// move_tile_window(a_warp_windows(mIter)(kIter), +// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); +// } +// } +// #else +// statically_indexed_array< +// statically_indexed_array, +// MIterPerWarp> +// a_warp_windows; + +// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { +// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { +// a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + +// move_tile_window(a_warp_windows(mIter)(kIter), +// {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); +// }); +// }); +// #endif + + // construct B-warp-window +// auto b_warp_window_tmp = make_tile_window( +// b_block_window.get_bottom_tensor_view(), +// make_tuple(number{}, number{}), +// b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, +// make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +// #if 0 // FIXME: using array will cause register spill +// array, NIterPerWarp> b_warp_windows{ +// {b_warp_window_tmp}}; + +// for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) +// { +// for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) +// { +// move_tile_window(b_warp_windows(nIter)(kIter), +// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); +// } +// } +// #else +// statically_indexed_array< +// statically_indexed_array, +// NIterPerWarp> +// b_warp_windows; + +// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { +// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { +// b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + +// move_tile_window(b_warp_windows(nIter)(kIter), +// {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); +// }); +// }); +// #endif } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index f510355aad..124e3c5c14 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -40,7 +40,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); } #else - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2); + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2); // return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); #endif } @@ -55,6 +55,96 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy static_assert(false, "Unsupported data type configuration for GEMM warp execution."); } } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALDSTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + static_assert(false, "Unsupported tensor_layout right now."); + } + else + { + //Number{}, Number{}, Number{}))), + constexpr index_t K2 = 16 / sizeof(ADataType); + constexpr index_t K1 = 2; + constexpr index_t K0 = KPerBlock / K1 / K2; + //Number{}, Number{}, Number{}))), + constexpr index_t M2 = 32; // MPERXDL + constexpr index_t M1 = 2; //MWAVE + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + else + { + static_assert(false, "Unsupported shape right now."); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLDSTileDistribution() + { + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + static_assert(false, "Unsupported tensor_layout right now."); + } + else + { + //Number{}, Number{}, Number{}))), + constexpr index_t K2 = 16 / sizeof(BDataType); + constexpr index_t K1 = 2; + constexpr index_t K0 = KPerBlock / K1 / K2; + //Number{}, Number{}, Number{}))), + constexpr index_t N2 = 32; // MPERXDL + constexpr index_t N1 = 2; //MWAVE + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + } + else + { + static_assert(false, "Unsupported shape right now."); + } + } + } + }; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index c0817e736b..f4cb58eff0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -133,7 +133,16 @@ struct GemmPipelineAGmemBGmemCRegV1 // global read 0 auto a_block_tile = load_tile(a_copy_dram_window); auto b_block_tile = load_tile(b_copy_dram_window); - + // if (threadIdx.x == 0) { + // constexpr auto span_2d = decltype(a_block_tile)::get_distributed_spans(); + // sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + // sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // printf("%f,", type_convert(a_block_tile(i_j_idx))); + // }); + // printf("\n"); + // }); + // } { // move to 1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); @@ -170,7 +179,17 @@ struct GemmPipelineAGmemBGmemCRegV1 store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); } } - + // __syncthreads(); + // if (threadIdx.x == 0) { + // for (int j = 0; j < 256; j++) { + // for(int i = 0; i < 32; i++) { + // int ik0 = i /8; + // int ik1 = i % 8; + // printf("%f,", type_convert(p_b_lds[ik1 + j * 8 + ik0 * 8 * 256])); + // } + // printf("\n"); + // } + // } index_t iCounter = num_loop - 1; while(iCounter > 0) { @@ -219,6 +238,17 @@ struct GemmPipelineAGmemBGmemCRegV1 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } + // if (threadIdx.x == 0) { + // constexpr auto span_2d = decltype(c_block_tile)::get_distributed_spans(); + // sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { + // sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // if(abs(type_convert(c_block_tile(i_j_idx)) - 32) > 0.1) + // printf("%d %f,", threadIdx.x, type_convert(c_block_tile(i_j_idx))); + // }); + // printf("\n"); + // }); + // } return c_block_tile; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 04091480d1..146440a9d7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -54,7 +54,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + make_tuple(number<(kMPerBlock) * 8>{}, number<8>{}, number<1>{}), number<8>{}, number<1>{}); @@ -77,7 +77,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}), number<8>{}, number<1>{}); @@ -130,74 +130,74 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } #elif 1 // fake XOR - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + // { + // using namespace ck_tile; - using ADataType = remove_cvref_t; + // using ADataType = remove_cvref_t; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + // constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + // constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); + // constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + // make_tuple(number{}, number<2>{}, number{}), + // number{}); - constexpr index_t kK1 = 16 / sizeof(ADataType); + // constexpr index_t kK1 = 16 / sizeof(ADataType); - constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - a_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); + // constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + // a_lds_block_desc_d1_d2_d3, + // make_tuple( + // make_xor_transform(make_tuple(number{}, number{}), kK1), + // make_pass_through_transform(2)), + // make_tuple(sequence<0, 2>{}, sequence<1>{}), + // make_tuple(sequence<0, 2>{}, sequence<1>{})); - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + // constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + // a_lds_block_desc_d4_d5_d6, + // make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + // make_pass_through_transform(kKPerBlock)), + // make_tuple(sequence<0, 1>{}, sequence<2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); - return a_lds_block_desc_m_k; - } + // return a_lds_block_desc_m_k; + // } - // fake XOR - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck_tile; + // // fake XOR + // template + // CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + // { + // using namespace ck_tile; - using BDataType = remove_cvref_t; + // using BDataType = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + // constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + // constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); + // constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + // make_tuple(number{}, number<2>{}, number{}), + // number{}); - constexpr index_t kK1 = 16 / sizeof(BDataType); + // constexpr index_t kK1 = 16 / sizeof(BDataType); - constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - b_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); + // constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + // b_lds_block_desc_d1_d2_d3, + // make_tuple( + // make_xor_transform(make_tuple(number{}, number{}), kK1), + // make_pass_through_transform(2)), + // make_tuple(sequence<0, 2>{}, sequence<1>{}), + // make_tuple(sequence<0, 2>{}, sequence<1>{})); - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + // constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + // b_lds_block_desc_d4_d5_d6, + // make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + // make_pass_through_transform(kKPerBlock)), + // make_tuple(sequence<0, 1>{}, sequence<2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc_n_k; - } + // return b_lds_block_desc_n_k; + // } #endif template