compile pass

This commit is contained in:
Zzz9990
2025-12-09 07:25:26 -06:00
parent 98ddeebdc0
commit 2f5eb26839
3 changed files with 36 additions and 34 deletions

View File

@@ -242,10 +242,13 @@ struct MoeFlatmmKernel
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// MXF4_Pipeline only has the scale of B, and granularityK is 32
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, ck_tile::bf8_t> ||
std::is_same_v<ADataType, ck_tile::fp8_t> ||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>;
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, pk_fp4_t>;
static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static_assert(AQUANT_Pipeline);
static_assert(BMXFP4_Pipeline);
static constexpr bool MXF8F6F4MFMA =
#ifdef __gfx950__
AQUANT_Pipeline && BMXFP4_Pipeline;
@@ -663,7 +666,7 @@ struct MoeFlatmmKernel
number<8>{},
number<1>{});
}
else if constexpr(std::is_same_v<ScaleType, e8m0_t>)
else if constexpr(std::is_same_v<AScaleType, e8m0_t>)
{
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
@@ -675,11 +678,11 @@ struct MoeFlatmmKernel
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
}
}();
@@ -694,16 +697,16 @@ struct MoeFlatmmKernel
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_navie_desc,
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
reinterpret_cast<const int32_t*>(scale_n.ptr), scale_b_desc);
}
else
@@ -827,7 +830,7 @@ struct MoeFlatmmKernel
make_tile_window(views.at(I3),
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{i_m / M_Pack, 0});
{coord_m / M_Pack, 0});
// constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
constexpr int XDLPerLoadScaleB =

View File

@@ -2493,15 +2493,15 @@ template <typename ADataType_,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
BlockGemmShape_,
Traits_,
Scheduler_,
HasHotLoop_,
TailNum_,
ComputeDataType_>
struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
Traits_,
Scheduler_,
HasHotLoop_,
TailNum_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
@@ -2519,7 +2519,7 @@ struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
};
template <typename Problem, typename PipelinePolicy = MXF8FlatmmPipelineAgBgCrPolicy>
template <typename Problem, typename PipelinePolicy = F8xMXF4FlatmmPipelineAgBgCrPolicy>
struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
{
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
@@ -2600,8 +2600,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; // 16 / 1 = 16
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; // 16 / 1 * 2 = 32
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
? DsReadPreload
@@ -2612,8 +2612,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
static constexpr index_t mfma_per_wg = 1; // 950 only
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize; // e.g. 16 * 128 / 16 / 64 = 2
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0); // e.g. (16 * 128) % (16 * 64) == 0
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
@@ -2982,7 +2982,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
a_copy_dram_window_tmp.get_bottom_tensor_view()),
a_copy_dram_window_tmp.get_window_lengths(),
a_copy_dram_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
PipelinePolicy::template MakeADramTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);

View File

@@ -536,19 +536,17 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
static inline constexpr auto wg_attr_num_access =
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
? WGAttrNumAccessEnum::Single
: WGAttrNumAccessEnum::Double;
WGAttrNumAccessEnum::Single;
// std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
// ? WGAttrNumAccessEnum::Single
// : WGAttrNumAccessEnum::Double;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
@@ -634,7 +632,8 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K2 = MPerBlock == 16 ? GetSmemPackA<Problem>() * APackedSize/ 4:
GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256