mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
working prefill shapes
This commit is contained in:
@@ -4,7 +4,8 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>; // GemmConfigPreshuffleQuantDecode<T>;
|
||||
// //GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
|
||||
@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"bf8i4 or bf16fp4")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "SplitK value")
|
||||
.insert("device", "0", "Device id that will be used to run the kernel")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("flush_cache", "true", "Flush cache before running the kernel")
|
||||
.insert("rotating_count", "1000", "Rotating count")
|
||||
.insert("rotating_count", "0", "Rotating count")
|
||||
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
|
||||
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
|
||||
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")
|
||||
|
||||
@@ -298,15 +298,15 @@ struct QuantGemmKernel
|
||||
const auto bq_x = N * KPerBlockBQ; // 2x2 = 4
|
||||
const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2
|
||||
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("N:%d, QK_B:%d\n", N, QK_B);
|
||||
// printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
|
||||
// bq_x,
|
||||
// bq_y,
|
||||
// GetVectorSizeBQ,
|
||||
// KPerBlockBQ);
|
||||
// }
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("N:%d, QK_B:%d\n", N, QK_B);
|
||||
printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
|
||||
bq_x,
|
||||
bq_y,
|
||||
GetVectorSizeBQ,
|
||||
KPerBlockBQ);
|
||||
}
|
||||
|
||||
const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x),
|
||||
make_tuple(bq_x, 1),
|
||||
@@ -319,10 +319,10 @@ struct QuantGemmKernel
|
||||
// each thread block can process complete tiles without edge cases
|
||||
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4
|
||||
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("block_tile_size:%d \n", block_tile_size);
|
||||
// }
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("block_tile_size:%d \n", block_tile_size);
|
||||
}
|
||||
|
||||
const auto bq_pad0_desc = transform_tensor_descriptor(
|
||||
bq_desc,
|
||||
@@ -344,18 +344,17 @@ struct QuantGemmKernel
|
||||
const auto wave_tile_count_x =
|
||||
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
|
||||
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:
|
||||
// "
|
||||
// "%d, wave_tile_count_x: %d\n",
|
||||
// pad_bq_x,
|
||||
// WarpTileN,
|
||||
// NPerBlockBQ,
|
||||
// KPerBlockBQ,
|
||||
// wave_tile_size,
|
||||
// wave_tile_count_x);
|
||||
// }
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:"
|
||||
"%d, wave_tile_count_x: %d\n",
|
||||
pad_bq_x,
|
||||
WarpTileN,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ,
|
||||
wave_tile_size,
|
||||
wave_tile_count_x);
|
||||
}
|
||||
|
||||
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
|
||||
bq_pad0_desc,
|
||||
@@ -386,13 +385,11 @@ struct QuantGemmKernel
|
||||
// where merged_outer_dim = bq_y * wave_tile_count_x
|
||||
// This layout facilitates efficient block-to-data mapping
|
||||
const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("pad_wave_size:%d\n", pad_wave_size);
|
||||
// printf("Final bq tensor lengths: %d x %d \n",
|
||||
// bq_y * wave_tile_count_x,
|
||||
// pad_wave_size);
|
||||
// }
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_wave_size:%d\n", pad_wave_size);
|
||||
printf("Final bq tensor lengths: %d x %d \n", bq_y * wave_tile_count_x, pad_wave_size);
|
||||
}
|
||||
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
|
||||
bq_pad1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4
|
||||
@@ -1123,31 +1120,35 @@ struct QuantGemmKernel
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n)
|
||||
? (warp_n / QuantGroupSize::kN)
|
||||
: (QuantGroupSize::kN / warp_n);
|
||||
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
|
||||
constexpr auto tile_window_width = ck_tile::integer_least_multiple(
|
||||
warp_n * bqk_per_block, get_warp_size()); // 128
|
||||
constexpr auto tile_window_height =
|
||||
min(block_n,
|
||||
TilePartitioner::BlockGemmShape::BlockWarps::at(
|
||||
I1)); // block_n / warp_n; // 2 / 4 = 0
|
||||
(block_n > warpPerGroup) ? block_n / warpPerGroup : block_n;
|
||||
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
|
||||
|
||||
// if(get_thread_id() == 0)
|
||||
// {
|
||||
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
|
||||
// printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d,
|
||||
// block_n_idx: %d, "
|
||||
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
|
||||
// get_block_id(),
|
||||
// static_cast<int>(block_n),
|
||||
// static_cast<int>(warp_n),
|
||||
// static_cast<int>(bqk_per_block),
|
||||
// static_cast<int>(block_n_idx),
|
||||
// tile_window_width,
|
||||
// static_cast<int>(tile_window_height),
|
||||
// static_cast<int>(i_n));
|
||||
// }
|
||||
if(get_thread_id() == 0)
|
||||
{
|
||||
printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
|
||||
printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
|
||||
"bqk_per_block: %d, block_n_idx: %d, "
|
||||
"tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
|
||||
get_block_id(),
|
||||
static_cast<int>(block_n),
|
||||
static_cast<int>(warp_n),
|
||||
static_cast<int>(warpPerGroup),
|
||||
static_cast<int>(bqk_per_block),
|
||||
static_cast<int>(block_n_idx),
|
||||
tile_window_width,
|
||||
static_cast<int>(tile_window_height),
|
||||
static_cast<int>(i_n));
|
||||
}
|
||||
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
|
||||
@@ -308,6 +308,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
0)
|
||||
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("bq_dram_tile_window_step: %d, %d\n",
|
||||
bq_dram_tile_window_step[I0{}],
|
||||
bq_dram_tile_window_step[I1{}]);
|
||||
}
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
Reference in New Issue
Block a user