diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index 8141d99286..744c844040 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -71,7 +71,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - constexpr bool transposed_warp_gemm = false; + constexpr bool transposed_warp_gemm = true; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; @@ -85,6 +85,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s CodegenGemmShape, CodegenGemmTraits, QuantGroupSize, + transposed_warp_gemm, ComputeDataType, ck_tile::GemmPipelineScheduler::Intrawave, has_hot_loop_v, diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp index 0690c4884f..13c416110a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_preshuffle.cpp @@ -85,6 +85,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s CodegenGemmShape, CodegenGemmTraits, QuantGroupSize, + transposed_warp_gemm, ComputeDataType, ck_tile::GemmPipelineScheduler::Intrawave, has_hot_loop_v, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c42874ca55..87772f78fc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -262,9 +262,17 @@ using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl>>; +using 
WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed = + WarpGemmImpl>>; + using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< WarpGemmAttributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed = + WarpGemmImpl>>; + using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, 2>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 36a9955912..1f8b4f8adc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -385,6 +385,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCMLane = Impl::kCMLane; CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index d50b208946..5021fb9907 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -97,6 +97,7 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; @@ -104,9 +105,9 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher 
{ using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; - template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; }; template<> struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; }; diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index d6921208c7..17ef73107b 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -158,6 +158,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + static constexpr bool TransposeC = Problem::TransposeC; }; public: @@ -359,63 +360,181 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase if constexpr(Traits::PreshuffleQuant) { - // A view is created on top of the preshuffled AQ, where each row of the - // view is composed of a row from a warp tile within an AQ block tile. - // Multiple warp tile rows that belong to the same block tile are laid - // out as consecutive rows. 
- // - // When we need to multiply a C warp tile with an AQ warp tile, thread 0 - // in the warp will load AQ_warp_tile[0], thread 1 will load - // AQ_warp_tile[1], and so on, up to thread 63, which will load - // AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS in - // this context, but we use cross-lane operations to access the data. - // (Cross-lane operations are faster than using LDS.) - // - // Note that when the size of the AQ warp tile is smaller than the warp - // size, you need to pad the rows in the view to ensure that each thread - // can read one element. - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - constexpr uint32_t kTileRowsOfCPerThread = 4; + if constexpr(Traits::TransposeC) // transposed C + { + static_assert(false, + "Enabling both PreshuffleQuant and TransposeC is not supported yet."); + // TODO: + // A new tile distribution is needed for the Preshuffle and + // Transpose combination. For instance, with mnk at 16x16x32, lanes + // 0-15, 16-31, 32-47, and 48-63 must load the same elements of AQ. + } + else + { + // A view is created on top of the preshuffled AQ, where each row of + // the view is composed of a row from a warp tile within an AQ block + // tile. Multiple warp tile rows that belong to the same block tile + // are laid out as consecutive rows. + // + // When we need to multiply a C warp tile with an AQ warp tile, + // thread 0 in the warp will load AQ_warp_tile[0], thread 1 will + // load AQ_warp_tile[1], and so on, up to thread 63, which will load + // AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS + // in this context, but we use cross-lane operations to access the + // data. (Cross-lane operations are faster than using LDS.) + // + // Note that when the size of the AQ warp tile is smaller than the + // warp size, you need to pad the rows in the view to ensure that + // each thread can read one element. 
+ constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + constexpr uint32_t kTileRowsOfCPerThread = 4; - static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( - [&](auto c_row) { - // For a warp tile of [16x16x32], take thread 0 as an example. - // Its VGPR[0] stores the value from C_tile[0,0], VGPR[1] stores - // C_tile[1,0], VGPR[2] stores C_tile[2,0], and VGPR[3] stores - // C_tile[3,0]. This means VGPR[0] should be multiplied by - // AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], VGPR[2] by - // AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0]. + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + // For a warp tile of [16x16x32], take thread 0 as an + // example. Its VGPR[0] stores the value from C_tile[0,0], + // VGPR[1] stores C_tile[1,0], VGPR[2] stores C_tile[2,0], + // and VGPR[3] stores C_tile[3,0]. This means VGPR[0] should + // be multiplied by AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0], + // VGPR[2] by AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0]. - // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, 0] - // from thread 1, ..., and AQ_tile[3, 0] from thread 3. - decltype(threadIdx.x) pull_from_lane = 0; - if constexpr(WarpGemm::kM == 16) - { - pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * - kTileRowsOfCPerThread + - c_row) * - Traits::QScalesPerBlockRow + - kQScale; - } - else if constexpr(WarpGemm::kM == 32) - { - pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * - kTileRowsOfCPerThread + - ((c_row >> 2) << 3) + (c_row & 0b11)) * - Traits::QScalesPerBlockRow + - kQScale; - } - else - { - static_assert(false, "WarpGemm::kM is not 16 nor 32."); - } - auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter]; + // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1, + // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3. 
+ decltype(threadIdx.x) pull_from_lane = 0; + if constexpr(WarpGemm::kM == 16) + { + pull_from_lane = (__lane_id() / Traits::WarpGemm::kN * + kTileRowsOfCPerThread + + c_row) * + Traits::QScalesPerBlockRow + + kQScale; + } + else if constexpr(WarpGemm::kM == 32) + { + pull_from_lane = + (__lane_id() / Traits::WarpGemm::kN * + kTileRowsOfCPerThread + + ((c_row >> 2) << 3) + (c_row & 0b11)) * + Traits::QScalesPerBlockRow + + kQScale; + } + else + { + static_assert(false, "WarpGemm::kM is not 16 nor 32."); + } + auto& scale_reg = + aq_block_tensor.get_thread_buffer()[mIter]; - // cross lane ops + // cross lane ops + uint32_t scale_reg_dword; + + if constexpr(std::is_same_v) + { + scale_reg_dword = + ck_tile::bit_cast(scale_reg); + } + else + { + scale_reg_dword = static_cast(scale_reg); + } + + int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( + pull_from_lane << 2, + __builtin_bit_cast(int, scale_reg_dword)); + + float scale_reg_f = + Base::cvt_scale_to_fp32(gathered_scale_reg); + + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); + }); + } + } + else + { + if constexpr(Traits::TransposeC) // transposed C; compile-time branch so only the selected C layout path is instantiated + { + constexpr index_t reg_offset = mIter * Traits::AQPerBlock + kQScale; + constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + auto& scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset]; + float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); + }); + } + else + { + + // Need to multiply aquant with accumulated C + // + // The accumulated C tile has the standard distribution. 
For example + // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], + // [26,0], [27,0]. + // + // These elements are in different rows, need to get the scale value + // for the corresponding row. + // Based on aquant's tile distribution, it can be inferred which + // lane holds the relevant scale. For example, the scales + // corresponding to the 16 elements held by lane 0 are held by lanes + // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 + // respectively. + // + // These scales can be obtained using __builtin_amdgcn_ds_bpermute. + + // MIters per warp + constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; + + // Reg block offset based on mIter + constexpr index_t reg_block_offset = + ((mIter / mIters_per_warp) * Traits::AQPerBlock); + + constexpr index_t lane_base_offset = + (mIter % mIters_per_warp) * WarpGemm::kM; + + // Scale tensor offset along K + constexpr index_t src_reg_offset = reg_block_offset + kQScale; + + constexpr uint32_t kTileRows = 4; + constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; + + constexpr auto tbuf_offset = number< + typename CBlockTensor::ThreadTensorDesc{}.calculate_offset( + merge_sequences(sequence{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) { + // Multiply by 4 because output is stored in tiles of 4 + // x CNLane + constexpr uint32_t row_base = + ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) + + ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane); + + constexpr uint32_t reg_offset_for_row_data = + c_row / WarpGemm::kCMLane; + + // Lane index to source scale from + uint32_t src_lane_idx = + lane_base_offset + row_base + + (__lane_id() / WarpGemm::kN * kTileRows); + + // Directly index into thread buffer corresponding to + // desired row coefficient + auto& scale_reg = + 
aq_block_tensor.get_thread_buffer()[src_reg_offset]; uint32_t scale_reg_dword; if constexpr(std::is_same_v) @@ -427,97 +546,19 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase scale_reg_dword = static_cast(scale_reg); } + // Pull scale data across lanes int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - pull_from_lane << 2, - __builtin_bit_cast(int, scale_reg_dword)); + src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword)); float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f * - kA_cvt_scale * kB_cvt_scale); + c_block_tensor.get_thread_buffer()[tbuf_offset + + reg_offset_for_row_data] += + (c_warp_tensor + .get_thread_buffer()[reg_offset_for_row_data] * + scale_reg_f * kA_cvt_scale * kB_cvt_scale); }); - } - else - { - // Need to multiply aquant with accumulated C - // - // The accumulated C tile has the standard distribution. For example - // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], - // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], - // [26,0], [27,0]. - // - // These elements are in different rows, need to get the scale value - // for the corresponding row. - // Based on aquant's tile distribution, it can be inferred which - // lane holds the relevant scale. For example, the scales corresponding - // to the 16 elements held by lane 0 are held by lanes 0, 1, 2, 3, 8, 9, - // 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 respectively. - // - // These scales can be obtained using __builtin_amdgcn_ds_bpermute. 
- - // MIters per warp - constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; - - // Reg block offset based on mIter - constexpr index_t reg_block_offset = - ((mIter / mIters_per_warp) * Traits::AQPerBlock); - - constexpr index_t lane_base_offset = - (mIter % mIters_per_warp) * WarpGemm::kM; - - // Scale tensor offset along K - constexpr index_t src_reg_offset = reg_block_offset + kQScale; - - constexpr uint32_t kTileRows = 4; - constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; - - constexpr auto tbuf_offset = - number{}, - c_warp_y_index_zeros)) / - CBlockTensor::PackedSize>{}; - - static_for<0, WarpGemm::kM, WarpGemm::kCMLane>{}([&](auto c_row) { - // Multiply by 4 because output is stored in tiles of 4 - // x CNLane - constexpr uint32_t row_base = - ((c_row / kTiledCMsPerWarp) * kTiledCMsPerWarp) + - ((c_row % kTiledCMsPerWarp) / WarpGemm::kCMLane); - - constexpr uint32_t reg_offset_for_row_data = - c_row / WarpGemm::kCMLane; - - // Lane index to source scale from - uint32_t src_lane_idx = lane_base_offset + row_base + - (__lane_id() / WarpGemm::kN * kTileRows); - - // Directly index into thread buffer corresponding to - // desired row coefficient - auto& scale_reg = - aq_block_tensor.get_thread_buffer()[src_reg_offset]; - uint32_t scale_reg_dword; - - if constexpr(std::is_same_v) - { - scale_reg_dword = ck_tile::bit_cast(scale_reg); - } - else - { - scale_reg_dword = static_cast(scale_reg); - } - - // Pull scale data across lanes - int gathered_scale_reg = __builtin_amdgcn_ds_bpermute( - src_lane_idx * 4, __builtin_bit_cast(int, scale_reg_dword)); - - float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); - - c_block_tensor - .get_thread_buffer()[tbuf_offset + reg_offset_for_row_data] += - (c_warp_tensor.get_thread_buffer()[reg_offset_for_row_data] * - scale_reg_f * kA_cvt_scale * kB_cvt_scale); - }); + } } }); }); diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp 
b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index c1fdeefc0c..5f15a15a45 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -50,7 +50,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - false>; + Problem::TransposeC>; static_assert(std::is_same_v); if constexpr(PreshuffleQuant) @@ -70,16 +70,30 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC } else { - using TileEncodingPattern = TileDistributionEncodingPatternAQ; + if constexpr(Problem::TransposeC) + { + using TileEncodingPatternTransposeC = + TileDistributionEncodingPatternAQTransposedC; + return TileEncodingPatternTransposeC::Make2DStaticTileDistribution(); + } + else + { + using TileEncodingPattern = TileDistributionEncodingPatternAQ; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::Make2DStaticTileDistribution(); + } } } @@ -98,7 +112,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - false>; + Problem::TransposeC>; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp index 4cca30fd3b..dfad7ba83d 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp @@ -18,6 +18,7 @@ template +struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern +{ + // TODO: make pattern where below condition does not need to hold - 
GGemmMultiDSplitk! + static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t num_warps = BlockSize / get_warp_size(); + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(number<0>{}); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(number<1>{}); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(number<2>{}); + + static constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarps * WarpGemm::kM); + + static_assert(num_warps == MWarps * NWarps * KWarps); + + // KWarps > 1 isn't supported + static_assert(KWarps == 1); + + // # of elements per thread + static constexpr index_t X = XPerTile; + static constexpr index_t XR = 2; + + // Number of iters per warp + // MIters are indexed using (Y0, Y1) + static constexpr index_t Y0 = MIterPerWarp; + + // # of warps in Y dim + static constexpr index_t Y1 = MWarps; + + static constexpr index_t Y2 = WarpGemm::kM; + + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); + + CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 2>, + sequence<0, 0>>{}); + } +}; + } // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index 3439309857..9ed42ff8d2 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -90,6 +90,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s CodegenGemmShape, CodegenGemmTraits, QuantGroupSize, + transposed_warp_gemm, ComputeDataType, ck_tile::GemmPipelineScheduler::Intrawave, has_hot_loop_v,