working prefill shapes

This commit is contained in:
khushbu
2025-12-16 18:21:44 -05:00
parent 5a3e7de060
commit 373d89d381
4 changed files with 63 additions and 55 deletions

View File

@@ -4,7 +4,8 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>; // GemmConfigPreshuffleQuantDecode<T>;
// //GemmConfigPreshuffleBQuantPrefill<T>;
void bquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)

View File

@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"bf8i4 or bf16fp4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "SplitK value")
.insert("device", "0", "Device id that will be used to run the kernel")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "1000", "Rotating count")
.insert("rotating_count", "0", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")

View File

@@ -298,15 +298,15 @@ struct QuantGemmKernel
const auto bq_x = N * KPerBlockBQ; // 2x2 = 4
const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("N:%d, QK_B:%d\n", N, QK_B);
// printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
// bq_x,
// bq_y,
// GetVectorSizeBQ,
// KPerBlockBQ);
// }
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("N:%d, QK_B:%d\n", N, QK_B);
printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
bq_x,
bq_y,
GetVectorSizeBQ,
KPerBlockBQ);
}
const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x),
make_tuple(bq_x, 1),
@@ -319,10 +319,10 @@ struct QuantGemmKernel
// each thread block can process complete tiles without edge cases
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("block_tile_size:%d \n", block_tile_size);
// }
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("block_tile_size:%d \n", block_tile_size);
}
const auto bq_pad0_desc = transform_tensor_descriptor(
bq_desc,
@@ -344,18 +344,17 @@ struct QuantGemmKernel
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:
// "
// "%d, wave_tile_count_x: %d\n",
// pad_bq_x,
// WarpTileN,
// NPerBlockBQ,
// KPerBlockBQ,
// wave_tile_size,
// wave_tile_count_x);
// }
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:"
"%d, wave_tile_count_x: %d\n",
pad_bq_x,
WarpTileN,
NPerBlockBQ,
KPerBlockBQ,
wave_tile_size,
wave_tile_count_x);
}
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
bq_pad0_desc,
@@ -386,13 +385,11 @@ struct QuantGemmKernel
// where merged_outer_dim = bq_y * wave_tile_count_x
// This layout facilitates efficient block-to-data mapping
const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("pad_wave_size:%d\n", pad_wave_size);
// printf("Final bq tensor lengths: %d x %d \n",
// bq_y * wave_tile_count_x,
// pad_wave_size);
// }
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_wave_size:%d\n", pad_wave_size);
printf("Final bq tensor lengths: %d x %d \n", bq_y * wave_tile_count_x, pad_wave_size);
}
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
bq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4
@@ -1123,31 +1120,35 @@ struct QuantGemmKernel
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n =
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n)
? (warp_n / QuantGroupSize::kN)
: (QuantGroupSize::kN / warp_n);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width = ck_tile::integer_least_multiple(
warp_n * bqk_per_block, get_warp_size()); // 128
constexpr auto tile_window_height =
min(block_n,
TilePartitioner::BlockGemmShape::BlockWarps::at(
I1)); // block_n / warp_n; // 2 / 4 = 0
(block_n > warpPerGroup) ? block_n / warpPerGroup : block_n;
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
// if(get_thread_id() == 0)
// {
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
// printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d,
// block_n_idx: %d, "
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
// get_block_id(),
// static_cast<int>(block_n),
// static_cast<int>(warp_n),
// static_cast<int>(bqk_per_block),
// static_cast<int>(block_n_idx),
// tile_window_width,
// static_cast<int>(tile_window_height),
// static_cast<int>(i_n));
// }
if(get_thread_id() == 0)
{
printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
"bqk_per_block: %d, block_n_idx: %d, "
"tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
get_block_id(),
static_cast<int>(block_n),
static_cast<int>(warp_n),
static_cast<int>(warpPerGroup),
static_cast<int>(bqk_per_block),
static_cast<int>(block_n_idx),
tile_window_width,
static_cast<int>(tile_window_height),
static_cast<int>(i_n));
}
return make_tile_window(
bq_pad_view,

View File

@@ -308,6 +308,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("bq_dram_tile_window_step: %d, %d\n",
bq_dram_tile_window_step[I0{}],
bq_dram_tile_window_step[I1{}]);
}
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);