From 3ea3ca7b36b4e33a86422d6ef0ccbdaac242a818 Mon Sep 17 00:00:00 2001 From: khuagarw Date: Sat, 6 Dec 2025 08:57:22 +0000 Subject: [PATCH] debugging --- .../run_gemm_quant_example.inc | 25 +++---- include/ck_tile/host/tensor_shuffle_utils.hpp | 32 ++++----- ...ock_universal_gemm_ar_flatbr_bquant_cr.hpp | 12 ++++ .../block_universal_gemm_as_bs_bquant_cr.hpp | 17 +++++ .../gemm_quant/kernel/gemm_quant_kernel.hpp | 42 ++++++----- .../gemm_bquant_pipeline_ag_bg_cr_policy.hpp | 14 ++-- .../pipeline/gemm_group_quant_utils.hpp | 72 ++++++++++++------- 7 files changed, 139 insertions(+), 75 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 9f413a9d75..b0a0d3fee7 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -677,18 +677,19 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::HostTensor bq_shuffle_host = ck_tile::shuffle_bq(bq_tensor_ptr.get(), GemmConfig::K_Tile / QuantGroupSize::kK); - printf("Preshuffle BQ tensor\n"); - for(int i = 0; i < static_cast(bq_shuffle_host.get_lengths()[0]); i++) - { - for(int j = 0; j < static_cast(bq_shuffle_host.get_lengths()[1]); j++) - { - for(int k = 0; k < static_cast(bq_shuffle_host.get_lengths()[2]); k++) - { - printf( - "bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j, k)); - } - } - } + // printf("Preshuffle BQ tensor\n"); + // for(int i = 0; i < static_cast(bq_shuffle_host.get_lengths()[0]); i++) + // { + // for(int j = 0; j < static_cast(bq_shuffle_host.get_lengths()[1]); j++) + // { + // for(int k = 0; k < static_cast(bq_shuffle_host.get_lengths()[2]); k++) + // { + // printf( + // "bq_shuffle_host[%d][%d][%d]: %f\n", i, j, k, bq_shuffle_host(i, j, + // k)); + // } + // } + // } bq_dev_buf_ptr->ToDevice(bq_shuffle_host.data()); } else diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 6524292bfd..c0c6f792c7 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -63,22 +63,22 @@ auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) int n_ = lengths[1]; ck_tile::HostTensor t_view({n_, bqk_dim / block_bq_k, block_bq_k}); std::copy(t->begin(), t->end(), t_view.begin()); - printf("I am inside shuffle_bq\n"); - printf("t_view.get_lengths(): %lu, %lu, %lu\n", - t_view.get_lengths()[0], - t_view.get_lengths()[1], - t_view.get_lengths()[2]); - for(int i = 0; i < static_cast(t_view.get_lengths()[0]); i++) - { - for(int j = 0; j < static_cast(t_view.get_lengths()[1]); j++) - { - for(int k = 0; k < static_cast(t_view.get_lengths()[2]); k++) - { - printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k)); - } - } - } - printf("I am inside shuffle_bq\n"); + // printf("I am inside shuffle_bq\n"); + // printf("t_view.get_lengths(): %lu, %lu, %lu\n", + // t_view.get_lengths()[0], + // t_view.get_lengths()[1], + // t_view.get_lengths()[2]); + // for(int i = 0; i < static_cast(t_view.get_lengths()[0]); i++) + // { + // for(int j = 0; j < static_cast(t_view.get_lengths()[1]); j++) + // { + // for(int k = 0; k < static_cast(t_view.get_lengths()[2]); k++) + // { + // printf("t_view[%d][%d][%d]: %f\n", i, j, k, t_view(i, j, k)); + // } + // } + // } + // printf("I am inside shuffle_bq\n"); return ck_tile::reference_permute(t_view, {1, 0, 2}); } } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 58b713cb35..ea19239f33 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -195,6 +195,18 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword)); float scale_reg_f = cvt_scale_to_fp32(gathered_scale_reg); + printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): %d, " + "kQScale: %d, pull_from_lane: %d, scale_reg: %f, " + "gathered_scale_reg: %d, scale_reg_f: %f\n", + get_block_id(), + get_thread_id(), + nIter, + __lane_id(), + static_cast(kQScale), + pull_from_lane, + scale_reg, + gathered_scale_reg, + scale_reg_f); static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index d97145cbc3..ac9e00ba84 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -355,6 +355,23 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg); + // if(get_block_id() ==0 && get_thread_id() == 0){ + printf("get_block_id(): %d, get_thread_id(): %d, nIter: %d, lane_id(): " + "%u, KQPerBLock: %d, " + "kQScale: %d, pull_from_lane: %u, scale_reg: %f, " + "gathered_scale_reg: %d, scale_reg_f: %f\n", + get_block_id(), + get_thread_id(), + static_cast(nIter), + __lane_id(), + Traits::KQPerBlock, + static_cast(kQScale), + pull_from_lane, + scale_reg, + gathered_scale_reg, + scale_reg_f); + //} + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( [&](auto c_row) { c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 5a413108e5..0ded65ce2e 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -280,12 +280,13 @@ struct QuantGemmKernel // Helper: Create Pre-shuffled Quantization Tensor Descriptor // =================================================================== template CK_TILE_DEVICE static auto - MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B) { // Step 1: Calculate base BQ tensor dimensions // ---------------------------------------------------------- @@ -316,7 +317,7 @@ struct QuantGemmKernel // ---------------------------------------------------------- // Pad the X dimension to be a multiple of block_tile_size to ensure // each thread block can process complete tiles without edge cases - const auto block_tile_size = NPerBlock * KPerBlockBQ; // 64 * 2 =128 + const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 64 * 2 = 128 || 8 * 2 = 16 if(get_block_id() == 0 && get_thread_id() == 0) { @@ -327,7 +328,7 @@ struct QuantGemmKernel bq_desc, make_tuple( make_pass_through_transform(bq_y), - make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 128 + make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), // 2, 16 make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -336,22 +337,28 @@ struct QuantGemmKernel // Split the X dimension into [wave_tile_count_x, wave_tile_size] // This separates the work into tiles that can be processed by // individual warps/waves - const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = WarpTileN * KPerBlockBQ; - const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); + const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 128 || 16 + const auto wave_tile_size = (WarpTileN / QN_B) * KPerBlockBQ; // 16 * 2= 32 || 16/8 x 2 = 4 + const auto wave_tile_count_x = + ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 128/32 = 4 || 16/4 = 4 if(get_block_id() == 0 && get_thread_id() == 0) { - printf("pad_bq_x:%d, wave_tile_size: %d, wave_tile_count_x: %d\n", + printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: " + "%d, wave_tile_count_x: %d\n", pad_bq_x, + WarpTileN, + NPerBlockBQ, + KPerBlockBQ, wave_tile_size, wave_tile_count_x); } const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( bq_pad0_desc, - make_tuple(make_pass_through_transform(bq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), + make_tuple( + make_pass_through_transform(bq_y), + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4 make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1, 2>{})); @@ -361,10 +368,11 @@ struct QuantGemmKernel // This ensures coalesced memory accesses within each warp const auto bq_pad1_desc = transform_tensor_descriptor( bq_unmerge_pad0_desc, - make_tuple(make_pass_through_transform(bq_y), - make_pass_through_transform(wave_tile_count_x), - make_right_pad_transform(wave_tile_size, - get_padding_size(wave_tile_size, get_warp_size()))), + make_tuple( + make_pass_through_transform(bq_y), // 2 + make_pass_through_transform(wave_tile_count_x), // 4 + make_right_pad_transform(wave_tile_size, + get_padding_size(wave_tile_size, get_warp_size()))), // 64 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); @@ -381,8 +389,8 @@ struct QuantGemmKernel } const auto bq_merge_pad1_desc = transform_tensor_descriptor( bq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), - make_pass_through_transform(pad_wave_size)), + make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 8 + make_pass_through_transform(pad_wave_size)), // 64 make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{})); @@ -817,11 +825,13 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return MakePreshuffledQuantTensorView< GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlockBQ, GemmPipeline::NPerBlock, TilePartitioner::BlockGemmShape::WarpTile::at(I1), GemmPipeline::GetVectorSizeBQ()>( bq_ptr, ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), + QuantGroupSize::kN, kargs.QK_B); } else @@ -1170,7 +1180,7 @@ struct QuantGemmKernel // bq_block_window.template print_tile_window_range( // 0, 128, 0, 2, "bq block window"); bq_block_window.template print_tile_window_range( - 0, 8, 0, 64, "bq block window"); + 0, 1, 0, 64, "bq block window"); } return GemmPipeline{}.template operator()( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 870326cb9d..a09deabab7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -35,12 +35,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC using BQLayout = remove_cvref_t; using BlockGemmShape = typename Problem::BlockGemmShape; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeBQ(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK; + // constexpr index_t VecLoadSize = GetVectorSizeBQ(); constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; @@ -61,7 +61,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC BlockSize, NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), - VecLoadSize, + Problem::QuantGroupSize::kN, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index b51dee752d..cb3333ba58 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -189,8 +189,8 @@ struct tile_distribution_encoding_pattern_aq_transposed_c template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern @@ -236,41 +236,65 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding { if constexpr(PreshuffleQuant) { - constexpr index_t X1 = warp_size; - constexpr index_t X0 = XPerTile / warp_size; - constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = YPerTile / Y1; + // constexpr index_t X1 = warp_size; + // constexpr index_t X0 = XPerTile / warp_size; + // constexpr index_t Y1 = NWarps; + // constexpr index_t Y0 = YPerTile / Y1; + + // return make_static_tile_distribution( + // tile_distribution_encoding, + // tuple, sequence>, + // tuple, sequence<2>>, // (Mwarp, Nwarp), + // (X1 = warp_size(64)) tuple, + // sequence<1>>, sequence<1, 2>, //(NiterPerWarp, + // X(threads in x dimension, 1)) sequence<0, 0>>{}); + + // constexpr index_t X1 = warp_size; //64 + constexpr index_t X0 = XPerTile / warp_size; // 64/64 = 1 + constexpr index_t X1 = XPerTile / WarpGemm::kN; // 64/16 = 4 + constexpr index_t X2 = WarpGemm::kN / YPerQ; // 16/8=2 + constexpr index_t XR = YPerQ; // 8 + constexpr index_t Y1 = NWarps; // 4 + constexpr index_t Y0 = YPerTile / Y1; // 1 + constexpr index_t YR = 1; return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2>, - sequence<0, 0>>{}); + tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple, sequence<0, 2, 0>>, // (Mwarp, Nwarp), + tuple, + sequence<1, 2, 2>>, //(repeat for 8 threads in X direction, X2(no of + // scales per warp), X1(warp_size/quant_group_size), + // YR)(8, 2, 4, 1) + sequence<1, 2>, + sequence<0, 0>>{}); } else { if constexpr(YPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t X = XPerTile; // Full X dimension of tile - constexpr index_t XR = 1; // No Y replication needed - constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t Y1 = NWarps; // Number of warps in N-dim - constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp - constexpr index_t YR = YPerQ; // Elements per quantization group + constexpr index_t X = XPerTile; // Full X dimension of tile + constexpr index_t XR = 1; // No Y replication needed + constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t Y1 = NWarps; // Number of warps in N-dim + constexpr index_t Y2 = + WarpGemm::kN / YPerQ; // Number of scales per warp 16/ 8 = 2 + constexpr index_t YR = YPerQ; // Elements per quantization group 8 static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 1, 0>>, - tuple, sequence<1, 2, 2>>, - sequence<1, 2>, - sequence<0, 0>>{}); + tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, //(Mwarp, Nwarp), (XR, Y2[no of + // scales per warp], YR) + tuple, sequence<1, 2, 2>>, + sequence<1, 2>, //(NiterPerWarp, X(threads in x dimension)) + sequence<0, 0>>{}); } else if constexpr(YPerQ <= WarpGemm::kN * NWarps) {