mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
compile pass
This commit is contained in:
@@ -242,10 +242,13 @@ struct MoeFlatmmKernel
|
||||
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
|
||||
|
||||
// MXF4_Pipeline only has the of scale B and granularityK is 32
|
||||
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, ck_tile::bf8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::pk_fp4_t>;
|
||||
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, pk_fp4_t>;
|
||||
static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
|
||||
static_assert(AQUANT_Pipeline);
|
||||
static_assert(BMXFP4_Pipeline);
|
||||
static constexpr bool MXF8F6F4MFMA =
|
||||
#ifdef __gfx950__
|
||||
AQUANT_Pipeline && BMXFP4_Pipeline;
|
||||
@@ -663,7 +666,7 @@ struct MoeFlatmmKernel
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScaleType, e8m0_t>)
|
||||
else if constexpr(std::is_same_v<AScaleType, e8m0_t>)
|
||||
{
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
@@ -675,11 +678,11 @@ struct MoeFlatmmKernel
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -694,16 +697,16 @@ struct MoeFlatmmKernel
|
||||
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
reinterpret_cast<const int32_t*>(scale_n.ptr), scale_b_desc);
|
||||
|
||||
}
|
||||
else
|
||||
@@ -827,7 +830,7 @@ struct MoeFlatmmKernel
|
||||
make_tile_window(views.at(I3),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{i_m / M_Pack, 0});
|
||||
{coord_m / M_Pack, 0});
|
||||
|
||||
// constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
constexpr int XDLPerLoadScaleB =
|
||||
|
||||
@@ -2493,15 +2493,15 @@ template <typename ADataType_,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
ComputeDataType_>
|
||||
struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
@@ -2519,7 +2519,7 @@ struct F8xMXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF8FlatmmPipelineAgBgCrPolicy>
|
||||
template <typename Problem, typename PipelinePolicy = F8xMXF4FlatmmPipelineAgBgCrPolicy>
|
||||
struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
@@ -2600,8 +2600,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; // 16 / 1 = 16
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; // 16 / 1 * 2 = 32
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
@@ -2612,8 +2612,8 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
|
||||
|
||||
static constexpr index_t mfma_per_wg = 1; // 950 only
|
||||
|
||||
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
|
||||
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
|
||||
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize; // 16 * 128 / 16 / 64 = 2
|
||||
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0); // 16 * 128 % 16 * 64
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
|
||||
@@ -2982,7 +2982,7 @@ struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Pr
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
|
||||
@@ -536,19 +536,17 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
template <typename Problem>
|
||||
static inline constexpr auto wg_attr_num_access =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
|
||||
? WGAttrNumAccessEnum::Single
|
||||
: WGAttrNumAccessEnum::Double;
|
||||
WGAttrNumAccessEnum::Single;
|
||||
// std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
|
||||
// ? WGAttrNumAccessEnum::Single
|
||||
// : WGAttrNumAccessEnum::Double;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(
|
||||
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
|
||||
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
|
||||
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
|
||||
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
@@ -634,7 +632,8 @@ struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K2 = MPerBlock == 16 ? GetSmemPackA<Problem>() * APackedSize/ 4:
|
||||
GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
|
||||
Reference in New Issue
Block a user