Fix and improve the gemm quant pipeline infrastructure (#3245)

This commit is contained in:
Thomas Ning
2025-11-26 18:04:27 -08:00
committed by GitHub
parent 79aae7c7f7
commit a38aeceb21
11 changed files with 96 additions and 272 deletions

View File

@@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
// ---- Lower 4 int4 values (even positions) ----
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
uint32_t sign = a >> 1;
// Build final selector:
// - bit 2 of each byte (0x04) selects negative vs positive table
// - 0x03020100 selects byte lanes [0,1,2,3] in order
final_sel = (sign & 0x04040404) | 0x03020100;
// Lookup positive and negative fp8 codes from the small register tables.
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
// ---- Upper 4 int4 values (odd positions) ----
// Shift to bring the high-nibble int4s into place and repeat the process.
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
dict_sel = a & 0x07070707;
sign = a >> 1;
final_sel = (sign & 0x04040404) | 0x03020100;
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
@@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
// ---- Lower 4 int4 values (even positions) ----
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
uint32_t dict_sel = a & 0x07070707;
uint32_t sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
uint32_t sign = a >> 1;
// Build final selector:
// - bit 2 of each byte (0x04) selects negative vs positive table
// - 0x03020100 selects byte lanes [0,1,2,3] in order
final_sel = (sign & 0x04040404) | 0x03020100;
// Lookup positive and negative bf8 codes from the small register tables.
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
// ---- Upper 4 int4 values (odd positions) ----
// Shift to bring the high-nibble int4s into place and repeat the process.
a >>= 4;
dict_sel = a & 0x07070707;
sign = a >> 1;
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(final_sel)
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
dict_sel = a & 0x07070707;
sign = a >> 1;
final_sel = (sign & 0x04040404) | 0x03020100;
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);

View File

@@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3
{
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
return TailNumber::Odd;
}
else
{
@@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3
// Handle all the valid cases.
if(has_hot_loop)
{
if(tail_number == TailNumber::Full)
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Full>{});
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
}
else
{
if(tail_number == TailNumber::Odd)
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Odd>{});
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(bool_constant<false>{},
integral_constant<TailNumber, TailNumber::Even>{});
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
}
#if defined(__HIP_DEVICE_COMPILE__)
@@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3
__builtin_unreachable();
#else
// If execution reaches here, it's an invalid combination of arguments.
if(has_hot_loop)
{
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
"be TailNumber::Full.");
}
else
{
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
"be TailNumber::Odd or TailNumber::Even.");
}
throw std::logic_error("Invalid TailNumber value: must be "
"TailNumber::Odd or TailNumber::Even");
#endif
}
};
@@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
} while(i < (num_loop - 1));
}
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
if constexpr(TailNum == TailNumber::Odd)
{
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency

View File

@@ -786,8 +786,8 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_naive_tensor_view<address_space_enum::global>(
bq_ptr,
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
make_tuple(1, kargs.stride_BQ),
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
make_tuple(kargs.stride_BQ, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}
@@ -1030,9 +1030,9 @@ struct QuantGemmKernel
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
return make_tile_window(
bq_pad_view,
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
{0, i_n / QuantGroupSize::kN});
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
{i_n / QuantGroupSize::kN, 0});
}
}
else

View File

@@ -15,68 +15,9 @@
namespace ck_tile {
template <typename Problem>
struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % BaseGemmPipelineAgBgCrCompV3<Problem>::PrefetchStages == 0)
{
return TailNumber::Even;
}
else
{
return TailNumber::Odd;
}
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
else
{
if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
}
};
// ToDo: Change the Pipeline to actual memory pipeline.
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Problem>
struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;

View File

@@ -14,74 +14,8 @@
namespace ck_tile {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem>
struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
else
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
}
};
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3<Problem>
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;

View File

@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ,
NPerBlockBQ,
KPerBlockBQ,
Problem::QuantGroupSize::kN>;
return TileEncodingPattern::make_2d_static_tile_distribution();

View File

@@ -20,68 +20,8 @@ namespace ck_tile {
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem>
struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
if(has_hot_loop)
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
else
{
if(tail_number == ck_tile::TailNumber::Full)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_number == ck_tile::TailNumber::Odd)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_number == ck_tile::TailNumber::Even)
{
return run_func(
ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Unsupported tail number for this operation !!!");
}
}
}
};
template <typename Problem, typename Policy = GemmBQuantPipelineAgBgCrDefaultPolicy>
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3<Problem>
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
@@ -318,8 +258,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
: is_bq_col_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);
: is_bq_col_major ? make_array(0, KPerBlockBQ)
: make_array(KPerBlockBQ, 0);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);

View File

@@ -171,7 +171,7 @@ template <typename BlockGemmShape,
index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t XPerQ,
index_t YPerQ,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
{
@@ -231,39 +231,39 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
}
else
{
if constexpr(XPerQ < WarpGemm::kN)
if constexpr(YPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
constexpr index_t Y = YPerTile; // Full Y dimension of tile
constexpr index_t YR = 1; // No Y replication needed
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t X1 = NWarps; // Number of warps in N-dim
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
constexpr index_t XR = XPerQ; // Elements per quantization group
constexpr index_t X = XPerTile; // Full X dimension of tile
constexpr index_t XR = 1;                    // No X replication needed
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
constexpr index_t YR = YPerQ; // Elements per quantization group
static_assert(X0 * X1 * X2 == XPerTile,
"X0, X1, X2 must cover the blocktile along X.");
static_assert(Y0 * Y1 * Y2 == YPerTile,
"Y0, Y1, Y2 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, YR, XR>,
tuple<sequence<Y>, sequence<X0, X1, X2>>,
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
tile_distribution_encoding<sequence<MWarps, XR, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
{
// Case 2: Medium-grained - one quantization scale per warp
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto X1 = NWarps / XR; // Warps per unique scale
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor
constexpr auto Y1 = NWarps / YR; // Warps per unique scale
constexpr auto Y0 = YPerTile / Y1; // Iterations to cover Y dimension
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
tuple<sequence<YPerTile>, sequence<X0, X1>>,
tuple<sequence<0, 2, 0>, sequence<0>>,
tile_distribution_encoding<sequence<MWarps, YR, get_warp_size()>,
tuple<sequence<Y0, Y1>, sequence<XPerTile>>,
tuple<sequence<0, 1, 0>, sequence<0>>,
tuple<sequence<0, 1, 1>, sequence<2>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{});
}
else // XPerQ > WarpGemm::kN * NWarps

View File

@@ -280,7 +280,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
}
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
@@ -338,7 +338,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
}
// Prefill A(2i+1)
@@ -390,7 +390,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
}
else
{
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
}
// Prefill A(2i+2)