diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp index a9f6dced9d..ecd9aa8d48 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -62,7 +62,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); - static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; + static constexpr index_t kflatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); @@ -170,9 +170,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] && KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]), "A block window has incorrect lengths for defined ALayout!"); - static_assert(Preshuffle // + static_assert(Preshuffle ? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && - kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]) + kflatKPerWarp == BsDramBlockWindowTmp{}.get_window_lengths()[I1]) : (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]), "B block window has incorrect lengths for defined BLayout!"); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 1a12eaa4fe..26a8c9ed44 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; + static constexpr index_t kDramLoadPackBytes = 128; + using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using ADataType = remove_cvref_t; @@ -33,14 +35,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy "ALayout must be RowMajor!"); static_assert(std::is_same_v, "BLayout must be ColumnMajor!"); - static_assert(is_any_of::value); - static_assert(is_any_of::value); + static_assert( + is_any_of::value); + static_assert( + is_any_of::value); static_assert(std::is_same_v); static_assert(std::is_same_v); - static constexpr auto WGAccess = std::is_same_v - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; + static constexpr auto WGAccess = + std::is_same_v || std::is_same_v + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; static constexpr auto PackedSize = numeric_traits::PackedSize; using BlockGemmShape = typename Problem::BlockGemmShape; @@ -92,10 +97,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr index_t KWarps = BlockWarps::at(I2); static constexpr index_t MIterPerWarp = MWarpTiles / MWarps; static constexpr index_t NIterPerWarp = NWarpTiles / NWarps; - static constexpr index_t KPerWarp = KPerBlock / KWarps; static constexpr index_t NPerWarp = NPerBlock / NWarps; static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); - static_assert(KWarpTiles == KWarps, "Wrong!"); static constexpr index_t warp_size = get_warp_size(); static constexpr index_t warp_num = BlockSize / warp_size; @@ -103,27 +106,35 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static_assert(warp_num * warp_size == BlockSize, "Wrong!"); static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!"); - static constexpr index_t ElementSize = sizeof(ADataType); - static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16 - static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 - static constexpr index_t K0 = KPerWarp / (K1 * K2); - static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); - static_assert(K0 == 1, "Wrong!"); + static_assert(sizeof(ADataType) == sizeof(ComputeDataType)); + static constexpr index_t ElementSize = sizeof(ComputeDataType); + static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; + // We define kDramLoadPackElems as 128 for fp6 because K2 == 16, so in this way we have correct + // values for K1 (number of contiguous lanes in a row) + static constexpr index_t kDramLoadPackElems = + std::is_same_v + ? kDramLoadPackBytes + : kDramLoadPackBytes / sizeof(ComputeDataType) * PackedSize; + static constexpr index_t PacksPerLdsRow = std::min(kDramLoadPackElems, KPerBlock) / K2; + static constexpr index_t K1 = PacksPerLdsRow; + static constexpr index_t K0 = KPerBlock / (K1 * K2); + static_assert(K0 * K1 * K2 == KPerBlock, "Wrong!"); + + static constexpr index_t SwizzleFactor = WarpTileK / static_cast(WGAccess) / K2; CK_TILE_DEVICE static constexpr bool IsPreshuffle() { return Preshuffle; } CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution() { - constexpr index_t M2 = warp_size / K1; // 8 - constexpr index_t M1 = warp_num; // 8 + constexpr index_t M2 = warp_size / K1; + constexpr index_t M1 = warp_num; constexpr index_t M0 = MPerBlock / M1 / M2; static_assert(M0 * M1 * M2 == MPerBlock, "wrong!"); return make_static_tile_distribution( ck_tile::tile_distribution_encoding< ck_tile::sequence<>, - ck_tile::tuple, // [123] 8 8 - ck_tile::sequence>, // 1 8 16 + ck_tile::tuple, ck_tile::sequence>, ck_tile::tuple, ck_tile::sequence<1, 2>>, // M0 M2,K1 ck_tile::tuple, ck_tile::sequence<2, 1>>, ck_tile::sequence<1, 2, 2>, // M0,K0,K2 @@ -134,36 +145,33 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy { if constexpr(Preshuffle) { - constexpr index_t K1_ = warp_size; // 64 - constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2; // 2 + constexpr index_t K1_ = warp_size; + constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2; static_assert(K0_ * K1_ * K2 == KPerBlock * WarpTileN, "wrong!"); - constexpr index_t N1 = warp_num / NWarps / K0_; // 2 - constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps; // 4 + constexpr index_t N1 = warp_num / NWarps / K0_; + constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps; static_assert(NWarps * N0 * N1 == NPerBlock / WarpTileN, "wrong!"); return make_static_tile_distribution( - tile_distribution_encoding< // - sequence<>, - tuple, // 2 [4] 2 - sequence>, // 2 64 16 - tuple, sequence<2>>, // NWarps,N1,K0 K1 - tuple, sequence<1>>, - sequence<1, 2>, // N0,K2 - sequence<1, 2>>{}); + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2>>, // NWarps,N1,K0 K1 + tuple, sequence<1>>, + sequence<1, 2>, // N0,K2 + sequence<1, 2>>{}); } else { - constexpr index_t N2 = warp_size / K1; // 8 - constexpr index_t N1 = warp_num / NWarps; // 4 - constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps; // 4 + constexpr index_t N2 = warp_size / K1; + constexpr index_t N1 = warp_num / NWarps; + constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps; static_assert(NWarps * N0 * N1 * N2 == NPerBlock, "wrong!"); return make_static_tile_distribution( - tile_distribution_encoding< // + tile_distribution_encoding< sequence<>, - tuple, // 2 [4] 4 8 - sequence>, // 1 8 16 + tuple, sequence>, tuple, sequence<1, 2>>, // NWarps,N1 N2,K1 tuple, sequence<3, 1>>, sequence<1, 2, 2>, // N0,K0,K2 @@ -182,7 +190,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy const index_t k_tiles = cols / (KWarps * K1 * K2); const auto col_lens = make_tuple(k_tiles, number{}, number{}, number{}); - constexpr index_t M1 = warp_size / static_cast(WGAccess) / K1; // 4 + constexpr index_t M1 = SwizzleFactor; const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); @@ -233,52 +241,64 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy template CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_() { - constexpr index_t M4 = warp_size / static_cast(WGAccess) / K1; // 4 - constexpr index_t M3 = static_cast(WGAccess); // 2 - constexpr index_t M2 = WarpTileM / M4 / M3; // 2 + constexpr index_t M4 = SwizzleFactor; + constexpr index_t M3 = warp_size / (M4 * K1); + constexpr index_t M2 = WarpTileM / M4 / M3; constexpr index_t M1 = (warp_num / warp_groups_) / M2; constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4; static_assert(M1 * M0 * M2 * M3 * M4 == MNPerBlock, "wrong!"); - constexpr index_t PadSize = 16; + constexpr index_t PadSize = SwizzleFactor * K2; + // Padding is needed between waves writing to LDS for a single mfma tile + // Example: instruction 16x128 8 bit + // 2 waves read elements from gmem to lds but they are consumed by a single + // wave when reading from LDS, so we need padding there but not between + // consecutive mfma tiles in LDS constexpr auto desc_0 = make_naive_tensor_descriptor( // - number_tuple{}, - number_tuple{}, + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), number{}, number<1>{}); constexpr auto desc_1 = transform_tensor_descriptor( desc_0, - make_tuple(make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), + make_tuple(make_pass_through_transform(number{}), make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_xor_transform(make_tuple(number{}, number{})), make_pass_through_transform(number{})), - container_concat(generate_tuple([](auto i) { return sequence{}; }, number<6>{}), - make_tuple(sequence<6, 7>{}), - make_tuple(sequence<8>{})), - container_concat(generate_tuple([](auto i) { return sequence{}; }, number<6>{}), - make_tuple(sequence<6, 7>{}), - make_tuple(sequence<8>{}))); - constexpr auto desc_2 = transform_tensor_descriptor( // + container_concat(generate_tuple([](auto i) { return sequence{}; }, number<5>{}), + make_tuple(sequence<5, 6>{}), + make_tuple(sequence<7>{})), + container_concat(generate_tuple([](auto i) { return sequence{}; }, number<5>{}), + make_tuple(sequence<5, 6>{}), + make_tuple(sequence<7>{}))); + + constexpr auto desc_2 = transform_tensor_descriptor( desc_1, make_tuple(make_merge_transform_v3_division_mod(number_tuple{}), - make_merge_transform_v3_division_mod(number_tuple{})), - make_tuple(sequence<3, 2, 0, 5, 6>{}, sequence<1, 4, 7, 8>{}), + make_merge_transform_v3_division_mod(number_tuple{})), + make_tuple(sequence<0, 2, 3, 4, 5>{}, sequence<1, 6, 7>{}), make_tuple(sequence<0>{}, sequence<1>{})); + return desc_2; } CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor() @@ -317,8 +337,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy else { constexpr index_t K1_ = warp_size / WarpTileN; // 4 - constexpr index_t K0_ = KPerWarp / K1_ / K2; // 2 - static_assert(K0_ * K1_ * K2 == KPerWarp, "wrong!"); + constexpr index_t K0_ = KPerBlock / K1_ / K2; // 2 + static_assert(K0_ * K1_ * K2 == KPerBlock, "wrong!"); constexpr index_t N2 = warp_size / K1_; // 16 constexpr index_t N1 = warp_num / NWarps / K0_; // 2 @@ -342,15 +362,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { + constexpr index_t sizeofType = + std::is_same_v ? 16 : sizeof(AComputeDataType); constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size(); - return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize, - 16); + return integer_least_multiple(sizeofType * desc_size / PackedSize, 16); } CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { + constexpr index_t sizeofType = + std::is_same_v ? 16 : sizeof(BComputeDataType); constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size(); - return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize, - 16); + return integer_least_multiple(sizeofType * desc_size / PackedSize, 16); } CK_TILE_DEVICE static constexpr index_t GetSmemSize() @@ -371,8 +393,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy // TODO: Fix for transpose constexpr auto wg_attr_num_access = WGAccess; - using WarpGemm = WarpGemmDispatcher; - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + using BlockGemmPolicy = + BlockGemmARegBRegCRegV1CustomPolicy; return BlockGemmARegBRegCRegEightWavesV1{}; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp index 08a979abfd..29b9373c42 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp @@ -63,8 +63,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< DstBlockTile& dst_block_tile, SrcTileWindow& lds_tile_window) const { + // swizzle factor limitation + using static_move_ys = + std::conditional_t, false_type, true_type>; lds_tile_window.set_bottom_tensor_view_data_ptr(smem); - lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, true_type{}); + lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, static_move_ys{}); } template @@ -72,6 +75,9 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< DstBlockTile& dst_block_tile, SrcTileWindow& lds_tile_window) const { + // swizzle factor limitation + using static_move_ys = + std::conditional_t, false_type, true_type>; lds_tile_window.set_bottom_tensor_view_data_ptr(smem); static_for_product, number>{}( [&](auto nIter, auto kIter) { @@ -80,7 +86,7 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< dst_block_tile[nIter][kIter], number<-1>{}, true_type{}, - true_type{}); + static_move_ys{}); }); } diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index ec5db8afdd..519b7afcd3 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -29,8 +29,8 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy using ComputeDataType = AComputeDataType; static_assert(std::is_same_v, "Wrong!"); static_assert(std::is_same_v, "Wrong!"); - static_assert(is_any_of::value); - static_assert(is_any_of::value); + static_assert(is_any_of::value); + static_assert(is_any_of::value); static_assert(std::is_same_v); static_assert(std::is_same_v); @@ -79,7 +79,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr index_t KPerWarp = KPerBlock / KWarps; static constexpr index_t NPerWarp = NPerBlock / NWarps; static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); - static_assert(KWarpTiles == KWarps, "Wrong!"); static constexpr index_t warp_size = get_warp_size(); static constexpr index_t warp_num = BlockSize / warp_size; @@ -92,7 +91,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 static constexpr index_t K0 = KPerWarp / (K1 * K2); static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); - static_assert(K0 == 1, "Wrong!"); CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 6a427aa471..23684a5521 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -311,20 +311,52 @@ using CompAsyncConfig16x16x128 = std::tuple; template -using CompAsyncEightWavesConfig = std::tuple; +using CompAsyncEightWavesConfig4Bit = std::tuple; + +template +using CompAsyncEightWavesConfig8Bit = std::tuple; + +template +using CompAsyncEightWavesConfig16Bit = std::tuple; using KernelTypesCompAsync = ::testing::Types, CompAsyncConfig, @@ -339,7 +371,11 @@ using KernelTypesCompAsync16x16x128 = ::testing::Types>; using KernelTypesCompAsyncEightWaves = - ::testing::Types>; + ::testing::Types, + CompAsyncEightWavesConfig8Bit, + CompAsyncEightWavesConfig4Bit, + CompAsyncEightWavesConfig16Bit, + CompAsyncEightWavesConfig16Bit>; // clang-format off using KernelTypesCompV6 = ::testing::Types< diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 56dfde3509..d1dc7a1813 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -63,8 +63,11 @@ constexpr ck_tile::index_t get_k_warp_tile() return 16; #endif #else - if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves) + if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves && sizeof(PrecType) == 1) return 128; + else if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves && + sizeof(PrecType) == 2) + return 32; // CompAsyncConfig16x16x128 else if constexpr(PipelineType == GemmPipelineType::CompAsync && M_Warp_Tile == 16) return 128;