mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix and improve the gemm quant pipeline infrastructure (#3245)
This commit is contained in:
@@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
// ---- Lower 4 int4 values (even positions) ----
|
||||
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
|
||||
uint32_t sign = a >> 1;
|
||||
// Build final selector:
|
||||
// - bit 2 of each byte (0x04) selects negative vs positive table
|
||||
// - 0x03020100 selects byte lanes [0,1,2,3] in order
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
// Lookup positive and negative fp8 codes from the small register tables.
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
// ---- Upper 4 int4 values (odd positions) ----
|
||||
// Shift to bring the high-nibble int4s into place and repeat the process.
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
@@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a)
|
||||
|
||||
uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel;
|
||||
|
||||
// ---- Lower 4 int4 values (even positions) ----
|
||||
// Extract dictionary indices: low 3 bits of each byte (values 0..7).
|
||||
uint32_t dict_sel = a & 0x07070707;
|
||||
uint32_t sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign.
|
||||
uint32_t sign = a >> 1;
|
||||
// Build final selector:
|
||||
// - bit 2 of each byte (0x04) selects negative vs positive table
|
||||
// - 0x03020100 selects byte lanes [0,1,2,3] in order
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
// Lookup positive and negative bf8 codes from the small register tables.
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
// Select per-lane between tmp_pos and tmp_neg using the sign-derived selector.
|
||||
tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel);
|
||||
|
||||
// ---- Upper 4 int4 values (odd positions) ----
|
||||
// Shift to bring the high-nibble int4s into place and repeat the process.
|
||||
a >>= 4;
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
asm volatile("v_and_or_b32 %0, %1, %2, %3"
|
||||
: "=v"(final_sel)
|
||||
: "v"(sign), "v"(0x04040404), "v"(0x03020100));
|
||||
dict_sel = a & 0x07070707;
|
||||
sign = a >> 1;
|
||||
final_sel = (sign & 0x04040404) | 0x03020100;
|
||||
|
||||
tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel);
|
||||
tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel);
|
||||
|
||||
@@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Full)
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Full>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
// If execution reaches here, it's an invalid combination of arguments.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
|
||||
"be TailNumber::Full.");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
|
||||
"be TailNumber::Odd or TailNumber::Even.");
|
||||
}
|
||||
throw std::logic_error("Invalid TailNumber value: must be "
|
||||
"TailNumber::Odd or TailNumber::Even");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
|
||||
@@ -786,8 +786,8 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bq_ptr,
|
||||
make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
|
||||
make_tuple(1, kargs.stride_BQ),
|
||||
make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B),
|
||||
make_tuple(kargs.stride_BQ, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}
|
||||
@@ -1030,9 +1030,9 @@ struct QuantGemmKernel
|
||||
using QuantGroupSize = remove_cvref_t<typename GemmPipeline::QuantGroupSize>;
|
||||
return make_tile_window(
|
||||
bq_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{},
|
||||
number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
|
||||
{0, i_n / QuantGroupSize::kN});
|
||||
make_tuple(number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{},
|
||||
number<TilePartitioner::KPerBlock / QuantGroupSize::kK>{}),
|
||||
{i_n / QuantGroupSize::kN, 0});
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -15,68 +15,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % BaseGemmPipelineAgBgCrCompV3<Problem>::PrefetchStages == 0)
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
}
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ToDo: Change the Pipeline to actual memory pipeline.
|
||||
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
|
||||
struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem<Problem>
|
||||
struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
@@ -14,74 +14,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmAQuantPipelineAgBgCrDefaultPolicy>
|
||||
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3<Problem>
|
||||
struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
@@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
tile_distribution_encoding_pattern_bq<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
KPerBlockBQ,
|
||||
NPerBlockBQ,
|
||||
KPerBlockBQ,
|
||||
Problem::QuantGroupSize::kN>;
|
||||
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
@@ -20,68 +20,8 @@ namespace ck_tile {
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == ck_tile::TailNumber::Full)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == ck_tile::TailNumber::Even)
|
||||
{
|
||||
return run_func(
|
||||
ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported tail number for this operation !!!");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmBQuantPipelineAgBgCrDefaultPolicy>
|
||||
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3<Problem>
|
||||
struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
@@ -318,8 +258,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV
|
||||
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
|
||||
BlockGemmShape::WarpTile::at(number<1>{}),
|
||||
0)
|
||||
: is_bq_col_major ? make_array(KPerBlockBQ, 0)
|
||||
: make_array(0, KPerBlockBQ);
|
||||
: is_bq_col_major ? make_array(0, KPerBlockBQ)
|
||||
: make_array(KPerBlockBQ, 0);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
@@ -171,7 +171,7 @@ template <typename BlockGemmShape,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t XPerQ,
|
||||
index_t YPerQ,
|
||||
bool PreshuffleQuant = false>
|
||||
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
|
||||
{
|
||||
@@ -231,39 +231,39 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(XPerQ < WarpGemm::kN)
|
||||
if constexpr(YPerQ < WarpGemm::kN)
|
||||
{
|
||||
// Case 1: Fine-grained - multiple quantization scales within a single warp
|
||||
constexpr index_t Y = YPerTile; // Full Y dimension of tile
|
||||
constexpr index_t YR = 1; // No Y replication needed
|
||||
constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t X1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp
|
||||
constexpr index_t XR = XPerQ; // Elements per quantization group
|
||||
constexpr index_t X = XPerTile; // Full X dimension of tile
|
||||
constexpr index_t XR = 1; // No X replication needed
|
||||
constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim
|
||||
constexpr index_t Y1 = NWarps; // Number of warps in N-dim
|
||||
constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp
|
||||
constexpr index_t YR = YPerQ; // Elements per quantization group
|
||||
|
||||
static_assert(X0 * X1 * X2 == XPerTile,
|
||||
"X0, X1, X2 must cover the blocktile along X.");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile,
|
||||
"Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, YR, XR>,
|
||||
tuple<sequence<Y>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<0, 2>, sequence<0, 2, 0>>,
|
||||
tile_distribution_encoding<sequence<MWarps, XR, YR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1, 0>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else if constexpr(XPerQ <= WarpGemm::kN * NWarps)
|
||||
else if constexpr(YPerQ <= WarpGemm::kN * NWarps)
|
||||
{
|
||||
// Case 2: Medium-grained - one quantization scale per warp
|
||||
constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto X1 = NWarps / XR; // Warps per unique scale
|
||||
constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension
|
||||
constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor
|
||||
constexpr auto Y1 = NWarps / YR; // Warps per unique scale
|
||||
constexpr auto Y0 = YPerTile / Y1; // Iterations to cover Y dimension
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<MWarps, XR, get_warp_size()>,
|
||||
tuple<sequence<YPerTile>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 2, 0>, sequence<0>>,
|
||||
tile_distribution_encoding<sequence<MWarps, YR, get_warp_size()>,
|
||||
tuple<sequence<Y0, Y1>, sequence<XPerTile>>,
|
||||
tuple<sequence<0, 1, 0>, sequence<0>>,
|
||||
tuple<sequence<0, 1, 1>, sequence<2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else // XPerQ > WarpGemm::kN * NWarps
|
||||
|
||||
@@ -280,7 +280,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
}
|
||||
// Prefill A0
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
@@ -338,7 +338,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
}
|
||||
|
||||
// Prefill A(2i+1)
|
||||
@@ -390,7 +390,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0});
|
||||
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
|
||||
}
|
||||
|
||||
// Prefill A(2i+2)
|
||||
|
||||
Reference in New Issue
Block a user