mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
some updates
This commit is contained in:
@@ -44,9 +44,9 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
static constexpr int MXdlPack = remove_cvref_t<typename FlatmmPipeline::MXdlPack>;
|
||||
static constexpr int NXdlPack = remove_cvref_t<typename FlatmmPipeline::NXdlPack>;
|
||||
static constexpr int KXdlPack = remove_cvref_t<typename FlatmmPipeline::KXdlPack>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -464,7 +464,8 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
@@ -261,10 +261,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
@@ -301,7 +303,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
@@ -463,12 +465,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
|
||||
@@ -32,7 +32,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
using QuantType = BDataType_;
|
||||
// using QuantType = BDataType_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
@@ -51,7 +51,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::QuantType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BuantType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
@@ -91,9 +91,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
// static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -118,33 +118,45 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::packed_size;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::packed_size;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
|
||||
static constexpr index_t MIterScalePerWarp = MIterPerWarp / MXdlPack;
|
||||
static constexpr index_t NIterScalePerWarp = NIterPerWarp / NXdlPack;
|
||||
static constexpr index_t KIterScalePerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
static constexpr int MXFP4PackedSize = 2;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * MXFP4PackedSize;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
|
||||
static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
|
||||
static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
|
||||
// static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
|
||||
// static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
|
||||
// static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
|
||||
|
||||
static constexpr int ScaleKFlatPerWarp =
|
||||
ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
|
||||
// static constexpr int ScaleKFlatPerWarp =
|
||||
// ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
|
||||
|
||||
static constexpr int XDLK_PerThread =
|
||||
WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
|
||||
// static constexpr int XDLK_PerThread =
|
||||
// WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
|
||||
|
||||
static constexpr int XDL_PerWeightK = 4; // 4
|
||||
static constexpr int XDL_PerScaleK = XDL_PerWeightK * ContinuousScaleKPerThread; // 4
|
||||
static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
|
||||
static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
|
||||
static_assert(KIterPerWarp % XDL_PerScaleK == 0);
|
||||
static_assert(NIterPerWarp % XDL_PerScaleN == 0);
|
||||
// static constexpr int XDL_PerWeightK = 4; // 4
|
||||
// static constexpr int XDL_PerScaleK = XDL_PerWeightK * ContinuousScaleKPerThread; // 4
|
||||
// static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
|
||||
// static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
|
||||
// static_assert(KIterPerWarp % XDL_PerScaleK == 0);
|
||||
// static_assert(NIterPerWarp % XDL_PerScaleN == 0);
|
||||
|
||||
static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
|
||||
static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
// static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
|
||||
// static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
// static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
|
||||
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
@@ -487,11 +499,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename DequantBFlatWindow>
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const DequantBFlatWindow& scale_b_flat_window,
|
||||
const ScaleADramBlockWindowTmp& scale_a_dram_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
@@ -507,6 +521,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
@@ -522,7 +537,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeF16xF4_ALdsBlockDescriptor<Problem>();
|
||||
PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
@@ -545,12 +560,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
@@ -562,22 +577,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
auto A_Lds_Stride = 8;
|
||||
// auto A_Lds_Stride = 8;
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
// a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
// move_tile_window(
|
||||
// a_warp_windows_ping(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// move_tile_window(
|
||||
// a_warp_windows_pong(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// });
|
||||
// });
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
auto packed_m_idx = mIter / number<MXdlPack>{};
|
||||
auto packed_m_rank = mIter % number<MXdlPack>{};
|
||||
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter,
|
||||
weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter,
|
||||
weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -588,9 +621,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
|
||||
auto scale_b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
|
||||
// auto scale_b_flat_distribution =
|
||||
// PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
@@ -598,11 +631,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
|
||||
scale_b_flat_window.get_window_origin(),
|
||||
scale_b_flat_distribution);
|
||||
// auto scale_b_flat_dram_window = make_tile_window(
|
||||
// scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
// make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
|
||||
// scale_b_flat_window.get_window_origin(),
|
||||
// scale_b_flat_distribution);
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
@@ -615,7 +648,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_Buffer, MXFP4KPerWarp>,
|
||||
@@ -625,18 +658,29 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_warp_tensor_pong;
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
|
||||
// ScaleNPerWarp>
|
||||
// scale_b_flat_dram_windows;
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
|
||||
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_ping;
|
||||
// statically_indexed_array<
|
||||
// statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)),
|
||||
// ScaleKPerWarp>, ScaleNPerWarp> scale_b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_dram_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock / MXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
|
||||
scale_a_draw_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
auto scale_b_dram_winodow = make_tile_window(
|
||||
scale_b_dram_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock / NXdlPack>{}, number<kKPerBlock / KXdlPack>{}),
|
||||
scale_b_dram_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
@@ -1191,10 +1235,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename DequantBFlatWindow>
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const DequantBFlatWindow& scale_b_flat_window,
|
||||
const ScaleADramblockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramblockWindowTmp& scale_b_flat_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
@@ -1203,7 +1249,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
scale_b_flat_window,
|
||||
scale_a_flat_window_tmp,
|
||||
scale_b_flat_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
|
||||
@@ -13,54 +13,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
// static constexpr index_t KBPerLoad = 32;
|
||||
// static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
// static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALdsBlockDescriptor()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
number<MPerBlock>{}, number<KPerBlock / KPack>{})), // xor on M
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
|
||||
|
||||
return a_lds_block_desc;
|
||||
// constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_0,
|
||||
// make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
|
||||
// number<ContiguousThreadsCntInDS_READ_16B>{})),
|
||||
// make_pass_through_transform(number<KPack>{})),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
// make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
// a_lds_block_desc_permuted,
|
||||
// make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
// make_merge_transform_v3_division_mod(
|
||||
// make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc_permuted;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_ADramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -92,7 +108,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -101,26 +117,26 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t MXdlPack = Problm::MXdlPack;
|
||||
constexpr int NWaves = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Repeat>,
|
||||
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
|
||||
tile_distribution_encoding<sequence<NWaves>,
|
||||
tuple<sequence<M0, MXdlPack>, sequence<K0, K1>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -140,7 +156,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Pack>, // second
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // second
|
||||
// direction
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
@@ -152,8 +168,40 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
// {
|
||||
// using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t WaveSize = get_warp_size();
|
||||
// constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
// constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
// constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
// constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
// constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<>, // ?
|
||||
// tuple<sequence<NWavePerBlk>, // second direction
|
||||
// sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// // direction
|
||||
// // wave in blk, // thd in wave
|
||||
// // <M, K> // <M, K>
|
||||
// tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
// tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// // <repeat, vec_load>
|
||||
// sequence<2>,
|
||||
// sequence<2>>{});
|
||||
// }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -161,27 +209,69 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr index_t kMPerBlock = tileShape::BlockTile::at(I0);
|
||||
|
||||
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t K_Lanes = 64 / M_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = M_Lanes;
|
||||
static constexpr index_t Y1 = M_Warps;
|
||||
static constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kNPerBlock = tileShape::BlockTile::at(I1);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(num_warps == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Lanes = 64 / N_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
static constexpr index_t Y2 = N_Lanes;
|
||||
static constexpr index_t Y1 = N_Warps;
|
||||
static constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
static constexpr index_t X0 = K_Lanes;
|
||||
static constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<M_Warps>, // ?
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user