Add mixed-precision (fp16 x fp4) flatmm kernel support

This commit is contained in:
Feng Shijie
2025-08-08 20:19:16 +00:00
parent 3dea10a277
commit f788d3d629
9 changed files with 252 additions and 123 deletions

View File

@@ -39,6 +39,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
@@ -89,16 +91,15 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
return make_naive_tensor_view<address_space_enum::global>(b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<32>{},
number<1>{});
}();
const auto& ds_tensor_view = generate_tuple(
@@ -307,7 +308,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
@@ -346,8 +348,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS

View File

@@ -371,8 +371,39 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
}
template <typename Problem, int PackSize = 1>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution(number<PackSize> = {})
template <typename Problem>
// Builds the static tile distribution used to read the fp16 A operand from
// DRAM for the mixed-precision (fp16 x fp4) flatmm pipeline.
// NOTE(review): the returned encoding is fully hard-coded (sequences 4 / 16 /
// {4,4,8}); the K0/K1/M1/M2 values computed below feed only the static_asserts,
// not the encoding itself — presumably the constants were specialized for one
// block/warp shape. Confirm the encoding matches Problem::BlockGemmShape.
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
// NOTE(review): ALayout is declared but never used in this function.
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// K1 = elements per 16-byte vector load of ADataType; K0 = vectors per K tile.
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
// M2 = rows covered by one warp's lanes after the K split; M1 = warps per block.
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
// Guard against integer division collapsing to zero for unsupported shapes.
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
// Hard-coded encoding: replicate dim <4>, hidden dims {16} and {4,4,8};
// partition maps (P) pick lanes from dims (0),(2,1); Y dims select <2>,<2>.
// TODO(review): derive these from KPerBlock/BlockSize instead of literals so
// the static_asserts above actually constrain the distribution.
return make_static_tile_distribution(
tile_distribution_encoding<sequence<4>,
tuple<sequence<16>, sequence<4, 4, 8>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -380,7 +411,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>() / PackSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
@@ -407,6 +438,42 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
// Builds the static tile distribution for loading the fp4-quantized B operand
// (flat layout) from DRAM. Mirrors MakeBFlatDramTileDistribution but with the
// per-lane K load width fixed at 32 — presumably 32 packed fp4 values
// (16 bytes) per lane; confirm against the fp4 PackedSize used by the kernel.
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
// Hard-coded vector width in K per lane (see note above); the generic B-flat
// distribution computes this via GetKBPerLoad<Problem>() instead.
constexpr index_t KBPerLoad = 32;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
// N dimension: one element per lane; warps tile N per BlockWarps[1].
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
// Extra replication when the block has more waves than flat-N warp slots.
// NOTE(review): assumes WaveNum >= flatNPerWarp and divides evenly — verify.
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{

View File

@@ -29,7 +29,12 @@ struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
TailNum_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
using QuantType = BDataType_;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t flatKPerWarp = 128;
};
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
@@ -68,8 +73,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
@@ -168,15 +173,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
index_t inst_order[NIterPerWarp * 10];
#pragma unroll
for(int idx = 0; idx < NIterPerWarp * 10; idx++)
{
inst_order[idx] = 0;
}
_Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
index_t index = 0;
#pragma unroll
for(int j = 0; j < max_data_inst; j++)
_Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
{
@@ -195,9 +195,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
}
}
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
// Schedule IGLP
_Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
if(j == 0)
@@ -211,8 +210,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
_Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
{
@@ -325,11 +323,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -390,11 +386,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -444,11 +438,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -524,18 +516,19 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto A_Warp_Dist = PipelinePolicy::template MakeF16xF4_ADramDistribution<Problem>();
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
A_Warp_Dist);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
A_Warp_Dist);
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -547,12 +540,14 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
MIterPerWarp>
a_warp_windows_pong;
constexpr int KStridePerIter = 8;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
});
});
@@ -561,7 +556,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
});
});
@@ -570,9 +565,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
constexpr int XDLPerLoadK = 4;
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>(number<2>{});
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
@@ -582,17 +580,17 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
@@ -604,7 +602,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -616,20 +614,6 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
// {
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
// shuffle_tile(a_shuffle_tmp, a_block_tile);
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// }
// else
// {
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func,
// a_block_tile));
// }
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
@@ -657,12 +641,23 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
});
__builtin_amdgcn_sched_barrier(0);
auto dequant_B = typename WG::BWarpTensor{};
auto deq_fn = [&](auto& quant_weight_tensor, auto sub_idx) {
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
number<i>{},
fp16x2_t(quant_weight_tensor.get_thread_buffer()[sub_idx * ScalarCnt / 2 + i]));
});
};
// MAIN LOOP
index_t iCounter = 0; // (num_loop - 1) / 2;
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -694,10 +689,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -737,7 +733,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -768,10 +764,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -815,7 +811,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -842,10 +838,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -892,10 +888,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -934,10 +930,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(