fixing the tile window

This commit is contained in:
khushbu
2025-12-11 14:34:57 -05:00
parent 341d0e31b3
commit 92cbe3c17d

View File

@@ -338,9 +338,9 @@ struct QuantGemmKernel
// This separates the work into tiles that can be processed by
// individual warps/waves
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : (QN_B/WarpTileN)) * KPerBlockBQ; // 32/16 x 2 = 4
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
if(get_block_id() == 0 && get_thread_id() == 0)
{
@@ -358,7 +358,7 @@ struct QuantGemmKernel
bq_pad0_desc,
make_tuple(
make_pass_through_transform(bq_y),
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 2, 2
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}));
@@ -369,8 +369,8 @@ struct QuantGemmKernel
const auto bq_pad1_desc = transform_tensor_descriptor(
bq_unmerge_pad0_desc,
make_tuple(
make_pass_through_transform(bq_y), // 2
make_pass_through_transform(wave_tile_count_x), // 1
make_pass_through_transform(bq_y), // 2, 2
make_pass_through_transform(wave_tile_count_x), // 1, 2
make_right_pad_transform(wave_tile_size,
get_padding_size(wave_tile_size, get_warp_size()))), // 64
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
@@ -1212,7 +1212,7 @@ struct QuantGemmKernel
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 128, 0, 2, "bq block window");
bq_block_window.template print_tile_window_range<BQDataType>(
0, 3, 0, 64, "bq block window");
0, 8, 0, 64, "bq block window");
}
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);