[CK_TILE] Add FP8xF4 Flatmm (#3401)

* Refactor policy

* fix a bank conflict

* Enable mixed mx flatmm

* Update
This commit is contained in:
Yi DING
2025-12-17 10:01:48 +08:00
committed by GitHub
parent 3dfa794fab
commit 57e1e4a848
9 changed files with 231 additions and 223 deletions

View File

@@ -8,6 +8,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, including the qr_ks_vs, qr_async, and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline
### Changed
@@ -36,6 +37,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
### Changed

View File

@@ -148,7 +148,7 @@ auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
.insert("n", "128", "n dimension")
.insert("n", "512", "n dimension")
.insert("k", "256", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
@@ -308,6 +308,28 @@ int run_mx_flatmm_example(int argc, char* argv[])
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
MXf8f4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp4xfp8")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXf4f8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
throw std::runtime_error("Unsupported data_type!");

View File

@@ -76,6 +76,69 @@ struct MXfp8_FlatmmConfig16
static constexpr bool TiledMMAPermuteN = false;
};
// Tile/launch configuration for the mixed-precision microscaling (MX) Flatmm
// kernel with FP8 A-operand and packed-FP4 B-operand inputs.
struct MXf8f4_FlatmmConfig16
{
// Workgroup-level tile extents (elements per block tile).
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// Warp grid within a workgroup: 1 warp in M, 4 in N, 1 in K.
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp MMA tile; 16x16x128 matches the f8f6f4 MFMA shape this change
// dispatches to (see the WarpGemmMfma_f32_16x16x128_f8f6f4 hunk).
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
// No boundary padding: M/N/K are assumed to divide the tile sizes evenly.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
// Occupancy hint: one workgroup per CU.
static constexpr int kBlockPerCu = 1;
// NOTE(review): "Paritioner" misspelling is kept deliberately — it presumably
// mirrors the field names the tile-partitioner consumers expect; confirm
// before renaming.
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
// Single LDS buffer (no double buffering / ping-pong in shared memory).
static constexpr bool DoubleSmemBuffer = false;
// N-direction repeats per warp: 256 / 16 / 4 = 4.
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
// Tile/launch configuration for the mixed-precision microscaling (MX) Flatmm
// kernel with packed-FP4 A-operand and FP8 B-operand inputs (mirror of
// MXf8f4_FlatmmConfig16 with the operand precisions swapped; values are
// currently identical).
struct MXf4f8_FlatmmConfig16
{
// Workgroup-level tile extents (elements per block tile).
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
// Warp grid within a workgroup: 1 warp in M, 4 in N, 1 in K.
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp MMA tile; 16x16x128 matches the f8f6f4 MFMA shape this change
// dispatches to (see the WarpGemmMfma_f32_16x16x128_f8f6f4 hunk).
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
// No boundary padding: M/N/K are assumed to divide the tile sizes evenly.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
// Occupancy hint: one workgroup per CU.
static constexpr int kBlockPerCu = 1;
// NOTE(review): "Paritioner" misspelling is kept deliberately — it presumably
// mirrors the field names the tile-partitioner consumers expect; confirm
// before renaming.
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
// Single LDS buffer (no double buffering / ping-pong in shared memory).
static constexpr bool DoubleSmemBuffer = false;
// N-direction repeats per warp: 256 / 16 / 4 = 4.
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,

View File

@@ -6,16 +6,20 @@ function(mx_flatmm_instance_generate FILE_LIST)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(DATA_TYPE FP4 FP8)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP8xFP4 FP4xFP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
set(A_DATA_TYPE ${DATA_TYPE})
set(B_DATA_TYPE ${DATA_TYPE})
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)

View File

@@ -414,12 +414,7 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
MXFlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window_with_distr,
const auto& c_block_tile = MXFlatmmPipeline{}(a_block_window,
b_flat_block_window,
scale_a_block_window,
scale_b_block_window,

View File

@@ -23,7 +23,7 @@ template <typename ADataType_,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
@@ -132,8 +132,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);
static constexpr index_t AK1 = 16 /*dwordx4*/ * APackedSize / sizeof(ADataType);
static constexpr index_t BK1 = 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType);
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
@@ -470,11 +470,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
// __builtin_amdgcn_sched_barrier(0);
}
CK_TILE_HOST_DEVICE static constexpr auto GetADramTileDistribution()
{
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
}
template <typename... Args>
CK_TILE_DEVICE auto operator()(Args&&... args) const
{
@@ -684,7 +679,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
// preload A00,A10... from lds
s_waitcnt_barrier</*vmcnt*/ dswrite_num_perK>();
s_waitcnt_barrier</*vmcnt*/ Bload_num + ScaleAload_num + ScaleBload_num>();
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;

View File

@@ -7,6 +7,8 @@
namespace ck_tile {
namespace detail {
template <typename Problem>
struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
@@ -14,27 +16,47 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
static constexpr auto I2 = number<2>{};
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr index_t DWORDx4 = 16;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
template <typename Problem>
static inline constexpr auto wg_attr_num_access =
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
? WGAttrNumAccessEnum::Single
: WGAttrNumAccessEnum::Double;
private:
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
using TileShape = typename Problem::BlockGemmShape;
using BlockWarps = typename TileShape::BlockWarps;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveNum = BlockSize / WaveSize;
static constexpr index_t MPerBlock = TileShape::kM;
static constexpr index_t NPerBlock = TileShape::kN;
static constexpr index_t KPerBlock = TileShape::kK;
static constexpr index_t MWarps = BlockWarps::at(I0);
static constexpr index_t NWarps = BlockWarps::at(I1);
static_assert(WaveNum == MWarps * NWarps, "Block warps do not match block size");
static constexpr index_t MPerXdl = TileShape::WarpTile::at(I0);
static constexpr index_t NPerXdl = TileShape::WarpTile::at(I1);
static constexpr index_t KPerXdl = TileShape::WarpTile::at(I2);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static constexpr index_t K_Lane = get_warp_size() / 16; // 4
static constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
public:
static constexpr index_t AK1 = DWORDx4 * APackedSize;
static constexpr index_t BK1 = DWORDx4 * BPackedSize;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
@@ -43,10 +65,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access<Problem>>;
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
@@ -56,28 +75,20 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
template <typename Problem, typename TensorView>
template <typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
@@ -106,25 +117,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = get_warp_size() / K1; // 8
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
constexpr index_t M2 = WaveSize / K1; // 8
constexpr index_t M1 = BlockSize / WaveSize; // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
@@ -139,28 +139,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t K2 = AK1; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = WaveSize / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
@@ -168,14 +156,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<M0>{},
number<M1>{},
number<K0>{},
number<M1>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
@@ -187,8 +175,8 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(M1),
make_pass_through_transform(K0),
make_pass_through_transform(M1),
make_pass_through_transform(M2),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
@@ -210,103 +198,71 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
make_tuple(sequence<0, 2, 3, 4>{}, sequence<1, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
constexpr int K_Lane = 64 / M_Lane; // 4
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr int K1 = K_Thread / num_access_v; // 16
return make_static_tile_distribution(
std::conditional_t<
num_access_v == 1,
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
if constexpr(K_Thread == AK1)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>, sequence<K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>,
tile_distribution_encoding< //
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>>{});
sequence<1>>{});
else
return make_static_tile_distribution(tile_distribution_encoding< //
sequence<NWarps>,
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
sequence<K_Thread / AK1, K_Lane, AK1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t BPack = numeric_traits<BDataType>::PackedSize;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t kKPerThread = 32;
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t K2 = kKPerThread / num_access_v;
return make_static_tile_distribution(
std::conditional_t< //
num_access_v == 1,
if constexpr(BK1 == K_Thread)
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2 / BPack>>, // 1 64 32
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
sequence<2>>{});
else
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2 / BPack>>, // 2 1 64 16
tuple<sequence<NWarps, NXdlPack>, // 4 2
sequence<K_Thread / BK1, K0, K1, BK1 / BPackedSize>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
sequence<0, 3>>{});
}
template <typename Problem, typename WindowTmp>
template <typename WindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto M_Warp_Tile = Problem::BlockGemmShape::WarpTile::at(I1);
constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
@@ -314,7 +270,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
static_assert(std::decay_t<decltype(window_tmp)>::get_num_of_dimension() == 2);
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
auto&& byte_tensor_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
@@ -331,39 +287,25 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
byte_tensor_view,
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
{origin_tmp[0], origin_tmp[1] / BPackedSize},
MakeMX_BFlatBytesDramTileDistribution<Problem>());
MakeMX_BFlatBytesDramTileDistribution());
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
constexpr index_t K_Lanes = 64 / M_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = M_Lanes;
constexpr index_t Y1 = M_Warps;
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
constexpr index_t Y1 = MWarps;
constexpr index_t Y0 = MPerBlock / (MXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
tile_distribution_encoding<sequence<NWarps>, // repeat NWarps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<0, 2>>,
@@ -371,36 +313,22 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
constexpr index_t K_Lanes = 64 / N_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = N_Lanes;
constexpr index_t Y1 = N_Warps;
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = NPerBlock / (NXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Warps>, // ?
tile_distribution_encoding<sequence<MWarps>, // ?
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
@@ -408,20 +336,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
constexpr index_t MWavePerBlk = M_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Wrap>, // ?
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
tile_distribution_encoding<sequence<NWarps>, // ?
tuple<sequence<MWarps, MPerXdl>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
@@ -430,20 +349,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Wrap>, // ?
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
tile_distribution_encoding<sequence<MWarps>, // ?
tuple<sequence<NWarps, NPerXdl>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
@@ -452,20 +362,41 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() /
APackedSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA<Problem>();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); }
};
} // namespace detail
struct MXFlatmmPipelineAgBgCrPolicy
{
#define FORWARD_METHOD_(method) \
template <typename Problem, typename... Args> \
CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... args) \
{ \
return detail::MXFlatmmPipelineAgBgCrPolicy<Problem>::method(std::forward<Args>(args)...); \
}
FORWARD_METHOD_(GetBlockFlatmm);
FORWARD_METHOD_(MakeMX_AAsyncLoadDramDescriptor);
FORWARD_METHOD_(MakeMX_ADramTileDistribution);
FORWARD_METHOD_(MakeMX_ALdsBlockDescriptor);
FORWARD_METHOD_(MakeMX_ALDS_TileDistribution);
FORWARD_METHOD_(MakeMX_BFlatBytesDramTileDistribution);
FORWARD_METHOD_(MakeMX_BFlatBytesDramWindow);
FORWARD_METHOD_(MakeMX_ScaleA_DramTileDistribution);
FORWARD_METHOD_(MakeMX_ScaleB_DramTileDistribution);
FORWARD_METHOD_(MakeMX_ScaleA_FlatDramTileDistribution);
FORWARD_METHOD_(MakeMX_ScaleB_FlatDramTileDistribution);
FORWARD_METHOD_(GetSmemSizeA);
FORWARD_METHOD_(GetSmemSize);
#undef FORWARD_METHOD_
};
} // namespace ck_tile

View File

@@ -306,10 +306,9 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
AttrNumAccess>>;
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,

View File

@@ -116,15 +116,12 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Ty
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// scale mfma based f8f6f4
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<I>; };
template<typename A, typename B, WGAttrNumAccessEnum I>
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<I>; };
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<I>; };
template<> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };