mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Fix quant group size N = 128 for both decode and prefill shapes
This commit is contained in:
@@ -18,7 +18,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
# gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
# gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
# gemm_bquant_quantgrouped_bf8.cpp
|
||||
# gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
# gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
gemm_bquant_quantgrouped_preshufflequant.cpp
|
||||
# gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
@@ -49,4 +49,10 @@ void bquant_quantgrouped_fp8_instance_factory(
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 64, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"fp8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x128x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigPreshuffleBQuantPrefill<T>; // GemmConfigPreshuffleQuantDecode<T>;
|
||||
// //GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigPreshuffleQuantDecode<T>;
|
||||
// //GemmConfigPreshuffleBQuantPrefill<T>;
|
||||
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
@@ -24,31 +24,33 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
|
||||
lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant",
|
||||
// "1x8x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
// ck_tile::fp8_t,
|
||||
// ck_tile::half_t,
|
||||
// float>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 8, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
// lut[hash_multiple_strings({"fp8",
|
||||
// "bquant",
|
||||
// "non-preshuffleb",
|
||||
// "preshufflequant",
|
||||
// "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig =
|
||||
// decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t,
|
||||
// float>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 16, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
// };
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
@@ -75,20 +77,19 @@ void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"fp8",
|
||||
// "bquant",
|
||||
// "non-preshuffleb",
|
||||
// "preshufflequant",
|
||||
// "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
// using TypeConfig =
|
||||
// decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t,
|
||||
// float>{});
|
||||
// using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
// return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
// TypeConfig,
|
||||
// QuantGroupSize,
|
||||
// ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
//};
|
||||
lut[hash_multiple_strings({"fp8",
|
||||
"bquant",
|
||||
"non-preshuffleb",
|
||||
"preshufflequant",
|
||||
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
TypeConfig,
|
||||
QuantGroupSize,
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
};
|
||||
// lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant",
|
||||
// "1x1x128"})] =
|
||||
// [](const ck_tile::ArgParser& arg_parser) {
|
||||
|
||||
@@ -89,8 +89,8 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser)
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void aquant_quantgrouped_preshufflequant_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_fp8_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_bf8_instance_factory(
|
||||
// std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
// void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>> lut;
|
||||
// aquant_quantgrouped_instance_factory(lut);
|
||||
// aquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
// bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
// bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
|
||||
@@ -349,7 +349,24 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
|
||||
if constexpr(PreshuffleQuant)
|
||||
{
|
||||
constexpr index_t reg_offset = nIter;
|
||||
// constexpr index_t reg_offset = nIter;
|
||||
constexpr index_t reg_offset = [&]() {
|
||||
if constexpr(GemmTraits::QuantGroupSize::kN >=
|
||||
(NWarp * WarpGemm::kN))
|
||||
{
|
||||
if constexpr(Traits::NPerBlock ==
|
||||
GemmTraits::QuantGroupSize::kN)
|
||||
return kQScale;
|
||||
else
|
||||
return nIter; // for prefill needs kQscale, for decode needs
|
||||
// nIter
|
||||
}
|
||||
else
|
||||
{
|
||||
return nIter;
|
||||
}
|
||||
}();
|
||||
|
||||
auto pull_from_lane =
|
||||
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;
|
||||
|
||||
@@ -373,13 +390,14 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
float scale_reg_f = Base::cvt_scale_to_fp32(gathered_scale_reg);
|
||||
|
||||
// printf("block_id: %d, warp_id: %d, thread_id(): %d, nIter: %d,
|
||||
// lane_id(): "
|
||||
// reg_offset: %d, lane_id(): "
|
||||
// "%u, kQScale: %d, pull_from_lane: %u, scale_reg: %f, "
|
||||
// "scale_reg_f: %f\n",
|
||||
// get_block_id(),
|
||||
// get_warp_id(),
|
||||
// get_thread_id(),
|
||||
// static_cast<int>(nIter),
|
||||
// reg_offset,
|
||||
// __lane_id(),
|
||||
// static_cast<int>(kQScale),
|
||||
// pull_from_lane,
|
||||
|
||||
@@ -346,7 +346,7 @@ struct QuantGemmKernel
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockPQ: %d, KPerBlockBQ: %d, wave_tile_size:"
|
||||
printf("pad_bq_x:%d, WarpTileN:%d, NPerBlockBQ: %d, KPerBlockBQ: %d, wave_tile_size:"
|
||||
"%d, wave_tile_count_x: %d\n",
|
||||
pad_bq_x,
|
||||
WarpTileN,
|
||||
@@ -1119,7 +1119,9 @@ struct QuantGemmKernel
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr auto block_n =
|
||||
TilePartitioner::NPerBlock / QuantGroupSize::kN; // 128/32 = 4
|
||||
(QuantGroupSize::kN <= TilePartitioner::NPerBlock)
|
||||
? TilePartitioner::NPerBlock / QuantGroupSize::kN
|
||||
: QuantGroupSize::kN / TilePartitioner::NPerBlock; // 128/32 = 4
|
||||
|
||||
constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); // 16
|
||||
constexpr auto warpPerGroup = (QuantGroupSize::kN < warp_n)
|
||||
@@ -1133,29 +1135,47 @@ struct QuantGemmKernel
|
||||
? block_n / warpPerGroup
|
||||
: block_n; // 4 / 2 = 2
|
||||
|
||||
auto block_n_idx = i_n / TilePartitioner::NPerBlock; // 0,1,2
|
||||
|
||||
if(get_thread_id() == 0)
|
||||
auto block_n_idx =
|
||||
i_n /
|
||||
TilePartitioner::NPerBlock; // 0,1,2 (i_n - TilePartitioner::NPerBlock) /
|
||||
// TilePartitioner::NPerBlock
|
||||
// For decode shapes GN: 128, Blocks needs to access 0,0,1,1,2,2 ...
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
|
||||
printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
|
||||
"bqk_per_block: %d, block_n_idx: %d, "
|
||||
"tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
|
||||
get_block_id(),
|
||||
static_cast<int>(block_n),
|
||||
static_cast<int>(warp_n),
|
||||
static_cast<int>(warpPerGroup),
|
||||
static_cast<int>(bqk_per_block),
|
||||
static_cast<int>(block_n_idx),
|
||||
tile_window_width,
|
||||
static_cast<int>(tile_window_height),
|
||||
static_cast<int>(i_n));
|
||||
block_n_idx = block_n_idx >> 1;
|
||||
}
|
||||
// if(get_thread_id() == 0)
|
||||
// {
|
||||
// printf("In MakeGemmTileWindows for BQ with PreshuffleQuant\n");
|
||||
// printf("block_id: %d, block_n: %d, warp_n: %d, warpPerGroup: %d, "
|
||||
// "bqk_per_block: %d, block_n_idx: %d, "
|
||||
// "tile_window_width: %d, tile_window_height: %d, i_n: %d\n",
|
||||
// get_block_id(),
|
||||
// static_cast<int>(block_n),
|
||||
// static_cast<int>(warp_n),
|
||||
// static_cast<int>(warpPerGroup),
|
||||
// static_cast<int>(bqk_per_block),
|
||||
// static_cast<int>(block_n_idx),
|
||||
// tile_window_width,
|
||||
// static_cast<int>(tile_window_height),
|
||||
// static_cast<int>(i_n));
|
||||
// }
|
||||
if(QuantGroupSize::kN > TilePartitioner::NPerBlock)
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height,
|
||||
0}); // normally needs block_n_idx * tile_window_height, for decode GN
|
||||
// : 128 needs 0,0, 1,1, 2,2 ...
|
||||
}
|
||||
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<tile_window_height>{}, number<tile_window_width>{}),
|
||||
{block_n_idx * tile_window_height, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1255,15 +1275,15 @@ struct QuantGemmKernel
|
||||
{
|
||||
n = kargs.N;
|
||||
}
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
|
||||
// To print Tile window after bq_pad0_desc
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 128, 0, 2, "bq block window");
|
||||
bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
0, 8, 0, 32, "bq block window");
|
||||
}
|
||||
// if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
// {
|
||||
// printf("In RunGemm, before GemmPipeline call for BQuantGrouped\n");
|
||||
// // To print Tile window after bq_pad0_desc
|
||||
// // bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// // 0, 128, 0, 2, "bq block window");
|
||||
// bq_block_window.template print_tile_window_range<BQDataType>(
|
||||
// 0, 8, 0, 16, "bq block window");
|
||||
// }
|
||||
return GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n);
|
||||
}
|
||||
|
||||
@@ -29,11 +29,11 @@ struct GemmBQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
// static_assert(NPerBlockBQ == 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
// static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
// "NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
|
||||
@@ -65,7 +65,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t NPerBlockBQ =
|
||||
(QuantGroupSize::kN <= BlockGemmShape::kN) ? BlockGemmShape::kN / QuantGroupSize::kN : 1;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
|
||||
@@ -297,14 +297,27 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 2
|
||||
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // 1
|
||||
constexpr auto N0 = 1; // NPerQ/WarpGemm::kN; // 1
|
||||
constexpr auto N2 = 1;
|
||||
constexpr auto NR1 = 32; // 32
|
||||
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*2*1*32)=1
|
||||
constexpr auto NR1 = 32; // NPerQ; // 32
|
||||
constexpr auto NR0 = warp_size / (N0 * N1 * N2 * NR1); // 64/(1*1*1*32)=1
|
||||
constexpr auto K1 = KPerTile;
|
||||
|
||||
if(get_block_id() == 0 && get_thread_id() == 0)
|
||||
{
|
||||
// Debug print to verify values
|
||||
printf("PreshuffleQuant Fine-grained: KPerQ: %d, NPerQ: %d, N1=%d, NR0=%d, "
|
||||
"KPerTile: %d \n",
|
||||
KPerQ,
|
||||
NPerQ,
|
||||
N1,
|
||||
NR0,
|
||||
KPerTile);
|
||||
}
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
|
||||
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<K1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
|
||||
sequence<1, 2>,
|
||||
|
||||
@@ -73,7 +73,7 @@ struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase<ADataType_,
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static_assert(BlockGemmShape::kM % QuantGroupSize::kM == 0);
|
||||
static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
|
||||
// static_assert(BlockGemmShape::kN % QuantGroupSize::kN == 0);
|
||||
static_assert(BlockGemmShape::kK % QuantGroupSize::kK == 0);
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
|
||||
Reference in New Issue
Block a user