mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 05:37:34 +00:00
fixing the tile window
This commit is contained in:
@@ -338,9 +338,9 @@ struct QuantGemmKernel
|
||||
// This separates the work into tiles that can be processed by
|
||||
// individual warps/waves
|
||||
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
|
||||
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : (QN_B/WarpTileN)) * KPerBlockBQ; // 32/16 x 2 = 4
|
||||
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1
|
||||
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
@@ -358,7 +358,7 @@ struct QuantGemmKernel
|
||||
bq_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(bq_y),
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 4, 4
|
||||
make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), // 2, 2, 2
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}));
|
||||
|
||||
@@ -369,8 +369,8 @@ struct QuantGemmKernel
|
||||
const auto bq_pad1_desc = transform_tensor_descriptor(
|
||||
bq_unmerge_pad0_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(bq_y), // 2
|
||||
make_pass_through_transform(wave_tile_count_x), // 1
|
||||
make_pass_through_transform(bq_y), // 2, 2
|
||||
make_pass_through_transform(wave_tile_count_x), // 1, 2
|
||||
make_right_pad_transform(wave_tile_size,
|
||||
get_padding_size(wave_tile_size, get_warp_size()))), // 64
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
@@ -1212,7 +1212,7 @@ struct QuantGemmKernel
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 128, 0, 2, "bq block window");
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 3, 0, 64, "bq block window");
|
||||
0, 8, 0, 64, "bq block window");
|
||||
}
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
|
||||
Reference in New Issue
Block a user