diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 5965009b3b..bc6c970411 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -5,7 +5,7 @@ #include #include #include "ck_tile/host/permute_pk_int4.hpp" -#include "ck_tile/host/shuffle_utils.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" template static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as corresponding " "C block tensor data type!"); + constexpr auto warp_size = get_warp_size(); // hot loop: static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -365,10 +366,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; float scale_reg_f = Base::cvt_scale_to_fp32(scale_reg); - static_for<0, WarpGemm::kM / 2, 1>{}([&](auto c_row) { - c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += - (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); - }); + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); }); }); }); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index bc2c9c603a..6f049a20a7 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -686,8 +686,8 @@ struct QuantGemmKernel static_assert(std::is_same_v); return make_naive_tensor_view( bq_ptr, - make_tuple(kargs.N, kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), + make_tuple(kargs.QK_B, kargs.N), + make_tuple(1, kargs.stride_BQ), number{}, number<1>{}); } @@ -905,9 +905,9 @@ struct QuantGemmKernel static_assert(std::is_same_v); return make_tile_window( bq_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + make_tuple(number{}, + number{}), + {0, i_n}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index f9278bf985..54bca21501 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -52,8 +52,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using TileEncodingPattern = tile_distribution_encoding_pattern_bq; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index c27fbf5b50..92b1316b34 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -254,8 +254,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); - static_assert(NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && - KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], + static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}], "Bq block window has incorrect lengths for defined BqLayout!"); static_assert(is_a_col_major @@ -313,7 +313,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV constexpr BDramTileWindowStep b_dram_tile_window_step = is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); constexpr BQDramTileWindowStep bq_dram_tile_window_step = - is_bq_col_major ? make_array(0, KPerBlockBQ) : make_array(KPerBlockBQ, 0); + is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -358,6 +358,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(HasHotLoop) { + constexpr index_t tail_count = + ((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2; index_t i = 0; do { @@ -403,7 +405,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV __builtin_amdgcn_sched_barrier(0); i += 1; - } while(i < (num_loop - 1)); + } while(i < (num_loop - tail_count)); } // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 54b64c34be..c13da6206f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -191,28 +191,28 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(KWarps == 1); // # of elements per thread - static constexpr index_t X = XPerTile; - static constexpr index_t XR = 2; + static constexpr index_t Y = YPerTile; + static constexpr index_t YR = 1; // Number of iters per warp // MIters are indexed using (Y0, Y1) - static constexpr index_t Y0 = NIterPerWarp; + static constexpr index_t X0 = NIterPerWarp; // # of warps in Y dim - static constexpr index_t Y1 = NWarps; + static constexpr index_t X1 = NWarps; - static constexpr index_t Y2 = WarpGemm::kN; + static constexpr index_t X2 = WarpGemm::kN; - static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); + static_assert(X0 * X1 * X2 == XPerTile, "X0, X1, X2 must cover the blocktile along Y."); CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 1>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2>>, tuple, sequence<1, 2>>, - sequence<1, 2>, + sequence<2, 1>, sequence<0, 0>>{}); } }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 01c1a72335..196f47badb 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -236,7 +236,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV // BQ DRAM window for load auto bq_copy_dram_window = make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), bq_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeBQDramTileDistribution()); @@ -269,7 +269,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV BQBlockTile bq_block_tile, bq_block_tile_2; bq_block_tile = load_tile(bq_copy_dram_window); // move BQ to tile 1 - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -318,7 +318,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile_2 = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); // Prefill A(2i+1) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -360,7 +360,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); // Prefill A(2i+2) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index b12259c773..0f7fdcdbc7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -5,7 +5,7 @@ #include "test_gemm_quant_base.hpp" #include "ck_tile/host/permute_pk_int4.hpp" -#include "ck_tile/host/shuffle_utils.hpp" +#include "ck_tile/host/tensor_shuffle_utils.hpp" struct GemmConfigBase {