formatting

This commit is contained in:
khushbu
2025-12-11 22:29:23 -05:00
parent 995d1a5cf6
commit 44aaaacbec
8 changed files with 184 additions and 124 deletions

View File

@@ -16,7 +16,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# gemm_aquant_quantgrouped_preshufflequant.cpp
# gemm_bquant_quantgrouped_bf8i4.cpp
# gemm_bquant_quantgrouped_fp8i4.cpp
gemm_bquant_quantgrouped_bf16mxfp4.cpp
# gemm_bquant_quantgrouped_bf16mxfp4.cpp
# gemm_bquant_quantgrouped_bf8.cpp
# gemm_bquant_quantgrouped_fp8.cpp
# gemm_bquant_quantgrouped_preshuffleb.cpp

View File

@@ -35,30 +35,45 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 32, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {

View File

@@ -33,14 +33,14 @@ auto create_args(int argc, char* argv[])
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"bf8i4 or bf16fp4")
.insert("warmup", "1", "Number of iterations before benchmarking the kernel")
.insert("repeat", "0", "Number of iterations to benchmark the kernel")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "SplitK value")
.insert("device", "0", "Device id that will be used to run the kernel")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "Flush cache before running the kernel")
.insert("rotating_count", "0", "Rotating count")
.insert("rotating_count", "1000", "Rotating count")
.insert("quant_mode", "bquant", "Choose aquant, bquant, tensor or rowcol")
.insert("preshuffleb", "false", "Enable preshuffle of tensor B")
.insert("preshufflequant", "false", "Enable preshuffle of quant tensor")

View File

@@ -357,9 +357,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
<< " QuantMode = " << quant_type_to_string(QuantMode)
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " PreshuffleB = \n"
<< (GemmConfig::PreshuffleB ? "true" : "false") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}

View File

@@ -349,6 +349,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
constexpr index_t reg_offset = nIter;
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
@@ -368,19 +369,20 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d, lane_id(): "
"%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
"scale_reg_f: %f\n",
get_block_id(),
get_warp_id(),
get_thread_id(),
static_cast<int>(nIter),
__lane_id(),
static_cast<int>(kQScale),
pull_from_lane,
scale_reg,
scale_reg_f);
// printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d,
// lane_id(): "
// "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
// "scale_reg_f: %f\n",
// get_block_id(),
// get_warp_id(),
// get_thread_id(),
// static_cast<int>(nIter),
// __lane_id(),
// static_cast<int>(kQScale),
// pull_from_lane,
// scale_reg,
// scale_reg_f);
static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
[&](auto c_row) {
c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=

View File

@@ -298,15 +298,15 @@ struct QuantGemmKernel
const auto bq_x = N * KPerBlockBQ; // 2x2 = 4
const auto bq_y = QK_B / KPerBlockBQ; // 4/2 = 2
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("N:%d, QK_B:%d\n", N, QK_B);
printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
bq_x,
bq_y,
GetVectorSizeBQ,
KPerBlockBQ);
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("N:%d, QK_B:%d\n", N, QK_B);
// printf("bq_x: %d, bq_y: %d, getVectorSizeBQ: %d, kPerBlockBQ: %d\n",
// bq_x,
// bq_y,
// GetVectorSizeBQ,
// KPerBlockBQ);
// }
const auto bq_desc = make_naive_tensor_descriptor(make_tuple(bq_y, bq_x),
make_tuple(bq_x, 1),
@@ -319,10 +319,10 @@ struct QuantGemmKernel
// each thread block can process complete tiles without edge cases
const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; // 2x2 = 4
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("block_tile_size:%d \n", block_tile_size);
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("block_tile_size:%d \n", block_tile_size);
// }
const auto bq_pad0_desc = transform_tensor_descriptor(
bq_desc,
@@ -337,22 +337,25 @@ struct QuantGemmKernel
// Split the X dimension into [wave_tile_count_x, wave_tile_size]
// This separates the work into tiles that can be processed by
// individual warps/waves
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; // 32/16 x 2 = 4 = 2
const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; // 4
const auto wave_tile_size =
((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1 /*QN_B/WarpTileN*/) *
KPerBlockBQ; // 32/16 x 2 = 4 = 2
const auto wave_tile_count_x =
ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); // 4/4 = 1 ==2
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size: "
"%d, wave_tile_count_x: %d\n",
pad_bq_x,
WarpTileN,
NPerBlockBQ,
KPerBlockBQ,
wave_tile_size,
wave_tile_count_x);
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:
// "
// "%d, wave_tile_count_x: %d\n",
// pad_bq_x,
// WarpTileN,
// NPerBlockBQ,
// KPerBlockBQ,
// wave_tile_size,
// wave_tile_count_x);
// }
const auto bq_unmerge_pad0_desc = transform_tensor_descriptor(
bq_pad0_desc,
@@ -383,13 +386,16 @@ struct QuantGemmKernel
// where merged_outer_dim = bq_y * wave_tile_count_x
// This layout facilitates efficient block-to-data mapping
const auto pad_wave_size = ck_tile::integer_least_multiple(wave_tile_size, get_warp_size());
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_wave_size:%d\n", pad_wave_size);
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("pad_wave_size:%d\n", pad_wave_size);
// printf("Final bq tensor lengths: %d x %d \n",
// bq_y * wave_tile_count_x,
// pad_wave_size);
// }
const auto bq_merge_pad1_desc = transform_tensor_descriptor(
bq_pad1_desc,
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 2
make_tuple(make_merge_transform(make_tuple(bq_y, wave_tile_count_x)), // 4
make_pass_through_transform(pad_wave_size)), // 64
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -1115,13 +1121,33 @@ struct QuantGemmKernel
if constexpr(PreshuffleQuant)
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN;
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto block_n =
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 64 / 32 = 2
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1);
constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
constexpr auto tile_window_width = ck_tile::integer_least_multiple(
warp_n * bqk_per_block, get_warp_size()); // 128
constexpr auto tile_window_height = block_n / warp_n; // 2
auto block_n_idx = i_n / block_n;
warp_n * bqk_per_block, get_warp_size()); // 128
constexpr auto tile_window_height =
min(block_n,
TilePartitioner::BlockGemmShape::BlockWarps::at(
I1)); // block_n / warp_n; // 2 / 4 = 0
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
// if(get_thread_id() == 0)
// {
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
// printf("block_id: %d, block_n: %d, warp_n: %d, bqk_per_block: %d,
// block_n_idx: %d, "
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
// get_block_id(),
// static_cast<int>(block_n),
// static_cast<int>(warp_n),
// static_cast<int>(bqk_per_block),
// static_cast<int>(block_n_idx),
// tile_window_width,
// static_cast<int>(tile_window_height),
// static_cast<int>(i_n));
// }
return make_tile_window(
bq_pad_view,
@@ -1226,15 +1252,15 @@ struct QuantGemmKernel
{
n = kargs.N;
}
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
// To print Tile window after bq_pad0_desc
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 128, 0, 2, "bq block window");
bq_block_window.template print_tile_window_range<BQDataType>(
0, 8, 0, 64, "bq block window");
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
// // To print Tile window after bq_pad0_desc
// // bq_block_window.template print_tile_window_range<BQDataType>(
// // 0, 128, 0, 2, "bq block window");
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 8, 0, 64, "bq block window");
// }
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
}

View File

@@ -196,7 +196,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t n,
[[maybe_unused]] index_t n,
index_t num_loop,
void* p_smem) const
{
@@ -280,9 +280,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(PreshuffleQuant)
? make_array(((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);

View File

@@ -245,56 +245,70 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
if constexpr(NPerQ <= WarpGemm::kN)
{
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
constexpr auto N0 = WarpGemm::kN / NPerQ; //BlockGemmShape::kN / KPerQ; // 1
constexpr auto N0 = WarpGemm::kN / NPerQ; // BlockGemmShape::kN / KPerQ; // 1
constexpr auto N2 = 1;
constexpr auto NR1 = NPerQ; // 16
constexpr auto NR1 = NPerQ; // 16
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*16)=2
constexpr auto K1 = NWarps; // 4
constexpr auto K0 = KPerTile / K1; // 1
constexpr auto K1 = NWarps; // 4
constexpr auto K0 = KPerTile / K1; // 1
constexpr auto KR = 1;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0, XR1, X1, YR)
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>, // (Mwarp, Nwarp),(XR0, X0,
// XR1, X1, YR)
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>, // (1, 4), (2, 1, 16, 2, 1)
sequence<1, 2>,
sequence<0, 2>>{});
}
else if constexpr(NPerQ <= WarpGemm::kN * NWarps)
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
{
constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 32/16 = 2
constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/2 = 2
constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/2 = 2
constexpr auto KR = NPerQ / WarpGemm::kN; // Scale replication factor 64/16 = 4
constexpr auto K1 = NWarps / KR; // Warps per unique scale 4/4 = 1
constexpr auto K0 = KPerTile / K1; // Iterations to cover N dimension 4/1 = 4
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
constexpr auto N0 = 1; //NPerQ/WarpGemm::kN; // 2
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
constexpr auto N2 = 1;
constexpr auto NR1 = NPerQ; // 32
constexpr auto NR1 = NPerQ; // 32
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
if(get_block_id() == 0 && get_thread_id() == 0)
{
// Debug print to verify values
printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d, get_warp_size(): %d, K0=%d, N0=%d\n",
MWarps,
K1,
KR,
get_warp_size(),
K0,
N0);
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// // Debug print to verify values
// printf("PreshuffleQuant Medium-grained: MWarps: %d, K1=%d, KR=%d,
// get_warp_size(): %d, K0=%d, N0=%d\n",
// MWarps,
// K1,
// KR,
// get_warp_size(),
// K0,
// N0);
// }
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
constexpr auto N2 = 1;
constexpr auto NR1 = 32; // 32
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
}
else