mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
update scale-preshuffle for MXF4
This commit is contained in:
@@ -11,7 +11,7 @@ struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
|
||||
@@ -97,17 +97,17 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::MixedPrecFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
@@ -134,7 +134,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::MixedPrecFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -282,10 +282,11 @@ float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
// TODO (sizeof(BDataType) / 2)
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
@@ -366,23 +367,26 @@ auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
constexpr int K_Pack = FlatmmConfig::K_Tile / FlatmmConfig::K_Warp_Tile / K_Lane;
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
static_assert(sizeof(T) * K_Pack * FlatmmConfig::N_Repeat <= 16, "inefficient pack policy");
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
k_ / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Tile,
|
||||
FlatmmConfig::N_Repeat,
|
||||
FlatmmConfig::N_Warp,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
// return ck_tile::reference_permute(shfl_scale, {0, 3, 5, 2, 4, 1});
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 5, 0, 2, 6, 1, 4});
|
||||
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
|
||||
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
@@ -59,8 +59,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
|
||||
// ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-8.f, 8.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -166,7 +165,7 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
const float rtol = 1e-3;
|
||||
const float rtol = 5e-3;
|
||||
const float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
|
||||
@@ -13,11 +13,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename FlatmmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
int SupportArch = 0> // 0 means no arch restriction
|
||||
struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
|
||||
@@ -20,26 +20,32 @@ template <typename ADataType_,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
ComputeDataType_>
|
||||
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
using QuantType = BDataType_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
static constexpr index_t flatKPerWarp = 128;
|
||||
|
||||
static constexpr int MXF4ScaleGranularityK = 32;
|
||||
|
||||
static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
|
||||
static constexpr int ContinuousScaleNPerThread = 2; // it's fixed for fp4
|
||||
static constexpr int ContinuousScaleKPerThread = 2; // it's fixed for fp4
|
||||
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MixedPrecFlatmmPipelineAgBgCrPolicy>
|
||||
struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem, typename PipelinePolicy = F16xMXF4FlatmmPipelineAgBgCrPolicy>
|
||||
struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
: FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
@@ -117,6 +123,31 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
|
||||
static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
|
||||
static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
|
||||
|
||||
static constexpr int MXFP4PackedSize = 2;
|
||||
|
||||
static constexpr int ScaleKFlatPerWarp =
|
||||
ContinuousScaleNPerThread * ContinuousScaleKPerThread * get_warp_size();
|
||||
|
||||
static constexpr int XDLK_PerThread =
|
||||
WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
|
||||
|
||||
static constexpr int XDL_PerWeightK = 4; // 4
|
||||
static constexpr int XDL_PerScaleK = XDL_PerWeightK * ContinuousScaleKPerThread; // 8 when ContinuousScaleKPerThread == 2
|
||||
static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
|
||||
static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
|
||||
static_assert(KIterPerWarp % XDL_PerScaleK == 0);
|
||||
static_assert(NIterPerWarp % XDL_PerScaleN == 0);
|
||||
|
||||
static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
|
||||
static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
|
||||
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -142,27 +173,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', WG::kM, WG::kN, WG::kK),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return PipelinePolicy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
@@ -502,6 +515,15 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
auto A_XDL_TileDist = make_static_tile_distribution(typename WG::AWarpDstrEncoding{});
|
||||
auto A_Lds_TileDist =
|
||||
PipelinePolicy::template MakeFp16xF4_DS_WRITE_ATileDistribution<Problem>();
|
||||
auto A_Lds_Stride = WG::kK;
|
||||
|
||||
// auto A_XDL_TileDist = PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>();
|
||||
// auto A_Lds_TileDist = PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
// auto A_Lds_Stride = 8;
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -513,27 +535,26 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
A_Lds_TileDist);
|
||||
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
A_Lds_TileDist);
|
||||
|
||||
auto A_Warp_Dist = PipelinePolicy::template MakeF16xF4_ADramDistribution<Problem>();
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
A_Warp_Dist);
|
||||
A_XDL_TileDist);
|
||||
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
A_Warp_Dist);
|
||||
A_XDL_TileDist);
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
@@ -545,23 +566,26 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
constexpr int KStridePerIter = 8;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
// auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
|
||||
// auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
|
||||
// move_tile_window(
|
||||
// a_warp_windows_ping(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
// move_tile_window(
|
||||
// a_warp_windows_pong(mIter)(kIter),
|
||||
// {mIter * MPerBlockPerIter,
|
||||
// weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -570,12 +594,6 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
constexpr int XDLPerLoadK = 4;
|
||||
constexpr int NRepeatPerScaleLoad = 2;
|
||||
|
||||
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
|
||||
constexpr int QuantNPerWarp = NIterPerWarp / NRepeatPerScaleLoad;
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
|
||||
@@ -588,41 +606,37 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
constexpr int ScaleB_BlockK = 16 * 2 * 4;
|
||||
// flatKPerWarp * KIterPerWarp * NRepeatPerScaleLoad / XDLPerLoadK;
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<ScaleB_BlockK>{}),
|
||||
make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
|
||||
scale_b_flat_window.get_window_origin(),
|
||||
scale_b_flat_distribution);
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
|
||||
statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_flat_dram_window), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
scale_b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), MXFP4KPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_flat_dram_windows;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_warp_tensor_ping;
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), QuantKPerWarp>,
|
||||
QuantNPerWarp>
|
||||
statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
|
||||
ScaleNPerWarp>
|
||||
scale_b_warp_tensor_pong;
|
||||
|
||||
// HEAD
|
||||
@@ -633,35 +647,42 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp});
|
||||
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
|
||||
// A_Lds_TileDist may differ from ADramTileDistribution
|
||||
auto a_block_tile_transformed = make_static_distributed_tensor<ComputeType>(A_Lds_TileDist);
|
||||
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
@@ -689,64 +710,44 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
auto dequant_B = typename WG::BWarpTensor{};
|
||||
|
||||
auto perm_scale = [&](auto lane_scale, auto xdl_k_idx) {
|
||||
#if defined(__gfx942__)
|
||||
lane_scale = __builtin_amdgcn_ds_bpermute(((get_lane_id() % 16) + 16 * xdl_k_idx) * 4,
|
||||
lane_scale);
|
||||
return lane_scale;
|
||||
#endif
|
||||
auto v2scale = __builtin_amdgcn_permlane32_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx < 2)
|
||||
{
|
||||
lane_scale = v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
lane_scale = v2scale[1];
|
||||
}
|
||||
auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
v2scale = __builtin_amdgcn_permlane16_swap(lane_scale, lane_scale, 0, 0);
|
||||
if constexpr(xdl_k_idx % 2 == 0)
|
||||
{
|
||||
lane_scale = v2scale[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
lane_scale = v2scale[1];
|
||||
}
|
||||
return lane_scale;
|
||||
};
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
|
||||
auto deq_fn = [&](const auto& quant_weight_tensor,
|
||||
const auto& scale_tensor,
|
||||
auto xdl_nIter,
|
||||
auto xdl_kIter) {
|
||||
auto b_idx_k = xdl_kIter % number<XDLPerLoadK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<NRepeatPerScaleLoad>{};
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_idx_n];
|
||||
|
||||
auto use_scale = scale;
|
||||
use_scale.data = perm_scale(scale.data, b_idx_k);
|
||||
|
||||
if constexpr(xdl_nIter == 0)
|
||||
if(blockIdx.x == 0 && threadIdx.x < 64 && get_lane_id() % 16 == 0)
|
||||
{
|
||||
printf("laneid = %2u xdl-k=%2d use-scale = "
|
||||
"%.2f\n",
|
||||
threadIdx.x,
|
||||
int(xdl_kIter),
|
||||
float(use_scale));
|
||||
}
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
|
||||
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
|
||||
number<i>{},
|
||||
pk_fp4_to_fp16x2(
|
||||
quant_weight_tensor.get_thread_buffer()[b_idx_k * ScalarCnt / 2 + i],
|
||||
static_cast<float>(use_scale)));
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
|
||||
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
|
||||
if constexpr(std::is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else if constexpr(std::is_same_v<ComputeType, bf16_t>)
|
||||
{
|
||||
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
|
||||
}
|
||||
};
|
||||
using ComputeV2Type =
|
||||
std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
|
||||
|
||||
static_for<0, PackedCnt, 1>{}([&](auto i) {
|
||||
dequant_B.get_thread_buffer().template set_as<ComputeV2Type>(
|
||||
i,
|
||||
pk_mxfp4_to_compute_v2(
|
||||
quant_weight_tensor.get_thread_buffer()[quant_idx_k * PackedCnt + i],
|
||||
static_cast<float>(scale)));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -755,34 +756,43 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -801,11 +811,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
@@ -835,8 +845,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -849,34 +859,42 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_ping(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
@@ -894,11 +912,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -927,8 +946,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleB_BlockK});
|
||||
move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
@@ -945,34 +964,43 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % NRepeatPerScaleLoad == 0)
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto dequant_n_iter = nIter / number<QuantNPerWarp>{};
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter) = scale_b_flat_dram_window;
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
scale_b_flat_dram_windows(dequant_n_iter)(kIter),
|
||||
{dequant_n_iter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_pong(dequant_n_iter)(kIter) =
|
||||
load_tile(scale_b_flat_dram_windows(dequant_n_iter)(kIter));
|
||||
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * ContinuousScaleNPerThread * NFlatPerBlockPerIter +
|
||||
packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
a_block_tile_transformed.get_thread_buffer() =
|
||||
tile_elementwise_in(a_element_func, a_block_tile).get_thread_buffer();
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
@@ -986,11 +1014,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -1039,11 +1068,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
dequant_mxfp4(b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
@@ -1084,11 +1114,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<NRepeatPerScaleLoad>{})(
|
||||
kIter / number<XDLPerLoadK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
dequant_mxfp4(b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
|
||||
@@ -7,76 +7,110 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// Selects, at compile time, the vector load size used for the B operand.
// Both branches forward to GetGlobalVectorLoadSize (defined outside this view) and
// differ only in the last template argument: NPerBlock when B is RowMajor,
// KPerBlock for any other layout.
// NOTE(review): the last argument presumably names the contiguous dimension of B —
// confirm against GetGlobalVectorLoadSize.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
// Block-tile extents along N and K, taken from the problem's BlockGemmShape.
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// RowMajor B: NPerBlock is passed for both extent arguments.
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Any non-RowMajor B: KPerBlock is passed as the last extent argument.
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
}
|
||||
// Compile-time constants of this policy, available to its distribution builders.
// NOTE(review): the unit of KBPerLoad (bytes vs. packed fp4 elements) is not stated
// here — confirm before reusing it in new layouts.
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t N_Pack = 2; // it's fixed for fp4
|
||||
static constexpr index_t K_Pack = 2; // it's fixed for fp4
|
||||
|
||||
// Returns Problem::VectorLoadSize divided by sizeof(ADataType).
// NOTE(review): presumably VectorLoadSize is in bytes, making the result the number
// of A elements per vector access — confirm against the Problem definition.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
// Derives a per-load K extent from the warp tile shape at compile time:
//   WarpTile[1] == 32  ->  WarpTile[2] / 2
//   WarpTile[1] == 16  ->  WarpTile[2] / 4
// Any other WarpTile[1] fails the static_assert in the else branch.
// NOTE(review): WarpTile[1] and WarpTile[2] presumably are the warp-level N and K
// extents (other asserts in this policy call WarpTile[1] "XDL_N") — confirm.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only the 16-wide case is accepted as the alternative to 32.
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<4>,
|
||||
tuple<sequence<16>, sequence<4, 4, 8>>,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
// Builds the static tile distribution for the A tile in the form named "DS_WRITE"
// (presumably the layout used when A is written to LDS — confirm with the caller).
// Requires a RowMajor A layout. The M extent is factored as M0 x M1 x M2 and the K
// extent as K128_Cnt x 4 x 4 x K1, all derived from the block size, the warp size,
// the block tile shape and Problem::VectorLoadSize.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp16xF4_DS_WRITE_ATileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// K1: A elements per vector access; K0: number of such vectors along KPerBlock.
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
// M2: rows covered by one warp when K0 lanes are spent on each row.
// It is zero when K0 exceeds the warp size, which the assert below rejects.
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
// M1: number of warps in the block.
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// M0: remaining factor along M; the assert checks the split is exact.
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
// unmerge K0 to K16_i x K4_1 x K4_2
|
||||
// then exchange the order of K4_1 and K4_2
|
||||
// NOTE(review): K128_Cnt uses truncating division and nothing visible here asserts
// that K0 is a multiple of 16 — confirm the supported KPerBlock values guarantee it.
constexpr index_t XDL_PerKBLoad = 4;
|
||||
constexpr index_t K128_Cnt = K0 / XDL_PerKBLoad / XDL_PerKBLoad;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K128_Cnt, XDL_PerKBLoad, XDL_PerKBLoad, K1>>,
|
||||
// The second index sequence below lists K positions 0, 2, 1 in that order; this
// presumably implements the K4_1 / K4_2 exchange described above — confirm against
// the tile_distribution_encoding parameter semantics.
tuple<sequence<1>, sequence<1, 2, 2, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0, 2, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M0 = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
|
||||
|
||||
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
|
||||
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
|
||||
constexpr int K0 = K_Lane; // 4
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Repeat>,
|
||||
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
@@ -86,37 +120,34 @@ struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = 32;
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Pack>, // second
|
||||
// direction
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -130,111 +161,25 @@ struct MixedPrecFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t N_Repeat = TileShape::kN / TileShape::WarpTile::at(I1) / N_Warp;
|
||||
constexpr index_t N_Pack = N_Repeat;
|
||||
|
||||
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
|
||||
constexpr index_t KBPerLoad = XDLPerBlock * N_Pack;
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Pack = XDLPerBlock / K_Lane;
|
||||
|
||||
// constexpr index_t RepeatScale = TileShape::WarpTile::at(I2) / ;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, 16, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
// Builds the static tile distribution for the "shuffled" A register block.
// Requires a ColumnMajor A layout. M is factored as M0 x M1 and K as four factors;
// which K factorisation is used depends on whether K2 * M0 fits within one warp.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// M1: A elements per vector access; M0: number of such vectors along kMPerBlock.
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
// Tile element count divided by the block size (presumably elements per thread).
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size >= (K2 * M0))
|
||||
{
|
||||
// K2 * M0 fits in one warp: K1 takes the remaining lanes, K0 is the warp count.
// NOTE(review): unlike the else branch, this branch has no static_assert tying
// K0 * K1 * K2 * K3 back to kKPerBlock.
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// K2 * M0 exceeds one warp: K2 is divided by K1 (giving K2_m) and the warp
// count is divided by the same K1 to give K0.
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
// The four K factors must reproduce kKPerBlock exactly.
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename
|
||||
// Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
tile_distribution_encoding<
|
||||
sequence<>, // ?
|
||||
tuple<sequence<NWavePerBlk>, // second direction
|
||||
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
|
||||
// direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<1>, sequence<2, 2>>, // which direction
|
||||
tuple<sequence<0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user