[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)

* update coherence
---------

Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
Zzz9990
2025-12-10 10:03:28 +08:00
committed by GitHub
parent 934ba1208a
commit 1aa93ef551
8 changed files with 88 additions and 29 deletions

View File

@@ -1413,10 +1413,10 @@ enum struct amd_buffer_coherence_enum
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
DEVICE_NT0 = 16,
DEVICE_NT1 = 18,
SYSTEM_NT0 = 17,
SYSTEM_NT1 = 19,
};
template <index_t N,

View File

@@ -1281,10 +1281,10 @@ enum struct amd_buffer_coherence_enum
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
DEVICE_NT0 = 16,
DEVICE_NT1 = 18,
SYSTEM_NT0 = 17,
SYSTEM_NT1 = 19,
};
template <index_t N,

View File

@@ -595,16 +595,44 @@ struct MoeFlatmmKernel
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
if constexpr(!FlatmmPipeline::BPreShufflePermute)
{
index_t kFlatK =
kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
return make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
FlatmmPipeline::BMemNTType>(
b_flat_ptr,
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}
else
{
index_t kFlatK = FlatmmPipeline::flatKPerWarp;
index_t kFlatN0 = (kargs.N >> 4);
index_t kFlatK0 = (kargs.K >> 7);
auto b_tensor_view_naive = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
FlatmmPipeline::BMemNTType>(
b_flat_ptr,
make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK),
make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
return transform_tensor_view(
b_tensor_view_naive,
make_tuple(
make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl),
make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
}();
// TODO: enable vector write for C in ColMajor

View File

@@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
CK_TILE_HOST static constexpr amd_buffer_coherence_enum
GetBMemNTType(index_t M, index_t N, index_t K)
{
ck_tile::ignore = N;
ck_tile::ignore = K;
if(M <= 416)
{
return ck_tile::amd_buffer_coherence_enum::WAVE_NT1;
}
return ck_tile::amd_buffer_coherence_enum::coherence_default;
}
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
{

View File

@@ -16,10 +16,12 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
@@ -28,6 +30,8 @@ struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
Scheduler_,
HasHotLoop_,
TailNum_,
BMemNTType_,
BPreShufflePermute_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
@@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
static constexpr auto BMemNTType = Problem::BMemNTType;
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
CK_TILE_HOST_DEVICE static constexpr auto
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
{

View File

@@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
static constexpr auto BMemNTType = Problem::BMemNTType;
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -16,10 +16,12 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
ADataType_,
CDataType_,
@@ -28,6 +30,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
Scheduler_,
HasHotLoop_,
TailNum_,
BMemNTType_,
BPreShufflePermute_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;

View File

@@ -316,10 +316,12 @@ template <typename ADataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
bool BPreShufflePermute_ = false,
typename ComputeDataType_ = ADataType_>
struct FlatmmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
@@ -353,6 +355,9 @@ struct FlatmmPipelineProblem
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr auto BMemNTType = BMemNTType_;
static constexpr bool BPreShufflePermute = BPreShufflePermute_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off