fix group 128 for both decode and prefill shapes

This commit is contained in:
khuagarw
2025-12-19 01:05:01 +00:00
parent 39afffac85
commit 5b4a67ec6d
10 changed files with 148 additions and 89 deletions

View File

@@ -18,7 +18,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
# gemm_bquant_quantgrouped_fp8i4.cpp
# gemm_bquant_quantgrouped_bf16mxfp4.cpp
# gemm_bquant_quantgrouped_bf8.cpp
# gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_fp8.cpp
# gemm_bquant_quantgrouped_preshuffleb.cpp
gemm_bquant_quantgrouped_preshufflequant.cpp
# gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;
using GemmConfig = GemmConfigQuantDecode<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
@@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
}

View File

@@ -4,8 +4,8 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>; // GemmConfigPreshuffleQuantDecode<T>;
// //GemmConfigPreshuffleBQuantPrefill<T>;
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
// //GemmConfigPreshuffleBQuantPrefill<T>;
void bquant_quantgrouped_preshufflequant_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
@@ -24,31 +24,33 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant",
// "1x8x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
// ck_tile::fp8_t,
// ck_tile::half_t,
// float>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
// lut[hash_multiple_strings({"fp8",
// "bquant",
// "non-preshuffleb",
// "preshufflequant",
// "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig =
// decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t,
// float>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
// };
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
@@ -75,20 +77,19 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"fp8",
// "bquant",
// "non-preshuffleb",
// "preshufflequant",
// "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
// using TypeConfig =
// decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t,
// float>{});
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
// TypeConfig,
// QuantGroupSize,
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
//};
lut[hash_multiple_strings({"fp8",
"bquant",
"non-preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(arg_parser);
};
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
// "1x1x128"})] =
// [](const ck_tile::ArgParser& arg_parser) {

View File

@@ -89,8 +89,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void aquant_quantgrouped_preshufflequant_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_fp8_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
void bquant_quantgrouped_fp8_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_bf8_instance_factory(
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
// void bquant_quantgrouped_fp8i4_instance_factory(
@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
// aquant_quantgrouped_instance_factory(lut);
// aquant_quantgrouped_preshufflequant_instance_factory(lut);
// bquant_quantgrouped_fp8_instance_factory(lut);
bquant_quantgrouped_fp8_instance_factory(lut);
// bquant_quantgrouped_bf8_instance_factory(lut);
// bquant_quantgrouped_fp8i4_instance_factory(lut);
// bquant_quantgrouped_bf8i4_instance_factory(lut);

View File

@@ -349,7 +349,24 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
if constexpr(PreshuffleQuant)
{
constexpr index_t reg_offset = nIter;
// constexpr index_t reg_offset = nIter;
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::QuantGroupSize::kN >=
(NWarp * WarpGemm::kN))
{
if constexpr(Traits::NPerBlock ==
GemmTraits::QuantGroupSize::kN)
return kQScale;
else
return nIter; // for prefill needs kQscale, for decode needs
// nIter
}
else
{
return nIter;
}
}();
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
@@ -373,13 +390,14 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
// printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d,
// lane_id(): "
// reg_offset: %d, lane_id(): "
// "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
// "scale_reg_f: %f\n",
// get_block_id(),
// get_warp_id(),
// get_thread_id(),
// static_cast<int>(nIter),
// reg_offset,
// __lane_id(),
// static_cast<int>(kQScale),
// pull_from_lane,

View File

@@ -346,7 +346,7 @@ struct QuantGemmKernel
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:"
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockBQ: %d, KPerBlockBQ: %d, wave_tile_size:"
"%d, wave_tile_count_x: %d\n",
pad_bq_x,
WarpTileN,
@@ -1119,7 +1119,9 @@ struct QuantGemmKernel
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
constexpr auto block_n =
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 128/32 = 4
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
? TilePartitioner::NPerBlock / QuantGroupSize::kN
: QuantGroupSize::kN / TilePartitioner::NPerBlock; // 128/32 = 4
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); // 16
constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n)
@@ -1133,29 +1135,47 @@ struct QuantGemmKernel
? block_n / warpPerGroup
: block_n; // 4 / 2 = 2
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
if(get_thread_id() == 0)
auto block_n_idx =
i_n /
TilePartitioner::NPerBlock; // 0,1,2 (i_n - TilePartitioner::NPerBlock) /
// TilePartitioner::NPerBlock
// For decode shapes GN: 128, Blocks needs to access 0,0,1,1,2,2 ...
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
{
printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
"bqk_per_block: %d, block_n_idx: %d, "
"tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
get_block_id(),
static_cast<int>(block_n),
static_cast<int>(warp_n),
static_cast<int>(warpPerGroup),
static_cast<int>(bqk_per_block),
static_cast<int>(block_n_idx),
tile_window_width,
static_cast<int>(tile_window_height),
static_cast<int>(i_n));
block_n_idx = block_n_idx >> 1;
}
// if(get_thread_id() == 0)
// {
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
// printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
// "bqk_per_block: %d, block_n_idx: %d, "
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
// get_block_id(),
// static_cast<int>(block_n),
// static_cast<int>(warp_n),
// static_cast<int>(warpPerGroup),
// static_cast<int>(bqk_per_block),
// static_cast<int>(block_n_idx),
// tile_window_width,
// static_cast<int>(tile_window_height),
// static_cast<int>(i_n));
// }
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
{
return make_tile_window(
bq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx, 0});
}
else
{
return make_tile_window(
bq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height,
0}); // normally needs block_n_idx * tile_window_height, for decode GN
// : 128 needs 0,0, 1,1, 2,2 ...
}
return make_tile_window(
bq_pad_view,
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
{block_n_idx * tile_window_height, 0});
}
else
{
@@ -1255,15 +1275,15 @@ struct QuantGemmKernel
{
n = kargs.N;
}
if(get_block_id() == 0 && get_thread_id() == 0)
{
printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
// To print Tile window after bq_pad0_desc
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 128, 0, 2, "bq block window");
bq_block_window.template print_tile_window_range<BQDataType>(
0, 8, 0, 32, "bq block window");
}
// if(get_block_id() == 0 && get_thread_id() == 0)
// {
// printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
// // To print Tile window after bq_pad0_desc
// // bq_block_window.template print_tile_window_range<BQDataType>(
// // 0, 128, 0, 2, "bq block window");
// bq_block_window.template print_tile_window_range<BQDataType>(
// 0, 8, 0, 16, "bq block window");
// }
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
}

View File

@@ -29,11 +29,11 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
// static_assert(NPerBlockBQ == 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
// "NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");

View File

@@ -65,7 +65,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t NPerBlockBQ =
(QuantGroupSize::kN <= BlockGemmShape::kN) ? BlockGemmShape::kN / QuantGroupSize::kN : 1;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }

View File

@@ -297,14 +297,27 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
}
else
{
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 1
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
constexpr auto N2 = 1;
constexpr auto NR1 = 32; // 32
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
constexpr auto NR1 = 32; // NPerQ; // 32
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*1*1*32)=1
constexpr auto K1 = KPerTile;
if(get_block_id() == 0 && get_thread_id() == 0)
{
// Debug print to verify values
printf("PreshuffleQuant Fine-grained: KPerQ: %d, NPerQ: %d, N1=%d, NR0=%d, "
"KPerTile: %d \n",
KPerQ,
NPerQ,
N1,
NR0,
KPerTile);
}
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
sequence<1, 2>,

View File

@@ -73,7 +73,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
static constexpr auto TailNum = TailNum_;
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
// static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
[[nodiscard]] CK_TILE_HOST static const std::string GetName()