From 92cbe3c17d089b52aedd9deca938194e4db1bd1b Mon Sep 17 00:00:00 2001 From: khushbu Date: Thu, 11 Dec 2025 14:34:57 -0500 Subject: [PATCH] fixing the tile window --- .../ops/gemm_quant/kernel/gemm_quant_kernel.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index f633f58802..bcce140293 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -338,9 +338,9 @@ struct QuantGemmKernel // This separates the work into tiles that can be processed by // individual warps/waves const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4 - const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : (QN_B/WarpTileN)) * KPerBlockBQ; // 32/16 x 2 = 4 + const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2 const auto wave_tile_count_x = - ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 + ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2 if(get_block_id() == 0 && get_thread_id() == 0) { @@ -358,7 +358,7 @@ struct QuantGemmKernel bq_pad0_desc, make_tuple( make_pass_through_transform(bq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4 + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 2, 2 make_tuple(sequence<0>{}, sequence<1>{}), make_tuple(sequence<0>{}, sequence<1, 2>{})); @@ -369,8 +369,8 @@ struct QuantGemmKernel const auto bq_pad1_desc = transform_tensor_descriptor( bq_unmerge_pad0_desc, make_tuple( - make_pass_through_transform(bq_y), // 2 - make_pass_through_transform(wave_tile_count_x), // 1 + make_pass_through_transform(bq_y), // 2, 2 + make_pass_through_transform(wave_tile_count_x), // 1, 2 make_right_pad_transform(wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), // 64 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -1212,7 +1212,7 @@ struct QuantGemmKernel // bq_block_window.template print_tile_window_range( // 0, 128, 0, 2, "bq block window"); bq_block_window.template print_tile_window_range( - 0, 3, 0, 64, "bq block window"); + 0, 8, 0, 64, "bq block window"); } return GemmPipeline{}.template operator()( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);