From 373d89d381eeec06ed18b6fa3a024285c5cb27e3 Mon Sep 17 00:00:00 2001 From: khushbu Date: Tue, 16 Dec 2025 18:21:44 -0500 Subject: [PATCH] working prefill shapes --- ...mm_bquant_quantgrouped_preshufflequant.cpp | 3 +- .../38_block_scale_gemm/gemm_quant.cpp | 6 +- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 103 +++++++++--------- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 6 + 4 files changed, 63 insertions(+), 55 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index 710ca0e8e1..2750112683 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ -4,7 +4,8 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigPreshuffleBQuantPrefill; +using GemmConfig = GemmConfigPreshuffleBQuantPrefill; // GemmConfigPreshuffleQuantDecode; + // //GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 9c8d8eba50..de9d691a01 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[]) "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " "bf8i4 or bf16fp4") - .insert("warmup", "50", "Number of iterations before benchmarking the kernel") - .insert("repeat", "1000", "Number of iterations to benchmark the kernel") + .insert("warmup", "1", "Number of iterations before benchmarking the kernel") + .insert("repeat", "0", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "SplitK value") .insert("device", "0", "Device id that will be used to run the kernel") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "Flush cache before running the kernel") - .insert("rotating_count", "1000", "Rotating count") + .insert("rotating_count", "0", "Rotating count") .insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol") .insert("preshuffleb", "false", "Enable preshuffle of tensor B") .insert("preshufflequant", "false", "Enable preshuffle of quant tensor") diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 8e16ad46ec..cb83d807f1 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -298,15 +298,15 @@ struct QuantGemmKernel const auto bq_x = N * KPerBlockBQ; // 2x2 = 4 const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2 - // if(get_block_id() == 0 && get_thread_id() == 0) - // { - // printf("N:%d, QK_B:%d\n", N, QK_B); - // printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n", - // bq_x, - // bq_y, - // GetVectorSizeBQ, - // KPerBlockBQ); - // } + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("N:%d, QK_B:%d\n", N, QK_B); + printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n", + bq_x, + bq_y, + GetVectorSizeBQ, + KPerBlockBQ); + } const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x), make_tuple(bq_x, 1), @@ -319,10 +319,10 @@ struct QuantGemmKernel // each thread block can process complete tiles without edge cases const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4 - // if(get_block_id() == 0 && get_thread_id() == 0) - // { - // printf("block_tile_size:%d \n", block_tile_size); - // } + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("block_tile_size:%d \n", block_tile_size); + } const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, @@ -344,18 +344,17 @@ struct QuantGemmKernel const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2 - // if(get_block_id() == 0 && get_thread_id() == 0) - // { - // printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: - // " - // "%d, wave_tile_count_x: %d\n", - // pad_bq_x, - // WarpTileN, - // NPerBlockBQ, - // KPerBlockBQ, - // wave_tile_size, - // wave_tile_count_x); - // } + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:" + "%d, wave_tile_count_x: %d\n", + pad_bq_x, + WarpTileN, + NPerBlockBQ, + KPerBlockBQ, + wave_tile_size, + wave_tile_count_x); + } const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( bq_pad0_desc, @@ -386,13 +385,11 @@ struct QuantGemmKernel // where merged_outer_dim = bq_y * wave_tile_count_x // This layout facilitates efficient block-to-data mapping const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - // if(get_block_id() == 0 && get_thread_id() == 0) - // { - // printf("pad_wave_size:%d\n", pad_wave_size); - // printf("Final bq tensor lengths: %d x %d \n", - // bq_y * wave_tile_count_x, - // pad_wave_size); - // } + if(get_block_id() == 0 && get_thread_id() == 0) + { + printf("pad_wave_size:%d\n", pad_wave_size); + printf("Final bq tensor lengths: %d x %d \n", bq_y * wave_tile_count_x, pad_wave_size); + } const auto bq_merge_pad1_desc = transform_tensor_descriptor( bq_pad1_desc, make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4 @@ -1123,31 +1120,35 @@ struct QuantGemmKernel static_assert(std::is_same_v); constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2 - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n) + ? (warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; constexpr auto tile_window_width = ck_tile::integer_least_multiple( warp_n * bqk_per_block, get_warp_size()); // 128 constexpr auto tile_window_height = - min(block_n, - TilePartitioner::BlockGemmShape::BlockWarps::at( - I1)); // block_n / warp_n; // 2 / 4 = 0 + (block_n > warpPerGroup) ? block_n / warpPerGroup : block_n; + auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2 - // if(get_thread_id() == 0) - // { - // printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n"); - // printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d, - // block_n_idx: %d, " - // "tile_window_width: %d, tile_window_height: %d, i_n: %d\n", - // get_block_id(), - // static_cast(block_n), - // static_cast(warp_n), - // static_cast(bqk_per_block), - // static_cast(block_n_idx), - // tile_window_width, - // static_cast(tile_window_height), - // static_cast(i_n)); - // } + if(get_thread_id() == 0) + { + printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n"); + printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, " + "bqk_per_block: %d, block_n_idx: %d, " + "tile_window_width: %d, tile_window_height: %d, i_n: %d\n", + get_block_id(), + static_cast(block_n), + static_cast(warp_n), + static_cast(warpPerGroup), + static_cast(bqk_per_block), + static_cast(block_n_idx), + tile_window_width, + static_cast(tile_window_height), + static_cast(i_n)); + } return make_tile_window( bq_pad_view, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index dcc90b3be7..af2d9a4b3a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -308,6 +308,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3