mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE MOE] add NT & preshuffle permute to cktile MOE (#3377)
* update coherence --------- Co-authored-by: Zzz9990 <Zzz9990>
This commit is contained in:
@@ -1413,10 +1413,10 @@ enum struct amd_buffer_coherence_enum
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
DEVICE_NT0 = 16,
|
||||
DEVICE_NT1 = 18,
|
||||
SYSTEM_NT0 = 17,
|
||||
SYSTEM_NT1 = 19,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -1281,10 +1281,10 @@ enum struct amd_buffer_coherence_enum
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
DEVICE_NT0 = 16,
|
||||
DEVICE_NT1 = 18,
|
||||
SYSTEM_NT0 = 17,
|
||||
SYSTEM_NT1 = 19,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -595,16 +595,44 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(!FlatmmPipeline::BPreShufflePermute)
|
||||
{
|
||||
index_t kFlatK =
|
||||
kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
FlatmmPipeline::BMemNTType>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp;
|
||||
index_t kFlatN0 = (kargs.N >> 4);
|
||||
index_t kFlatK0 = (kargs.K >> 7);
|
||||
|
||||
auto b_tensor_view_naive = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
FlatmmPipeline::BMemNTType>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK),
|
||||
make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
return transform_tensor_view(
|
||||
b_tensor_view_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl),
|
||||
make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
|
||||
@@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
/// Selects the buffer coherence setting used for B-matrix memory accesses.
///
/// Only M takes part in the decision; N and K are accepted for interface
/// symmetry and are assigned to ck_tile::ignore so they count as used.
///
/// @param M  problem M dimension; values <= 416 select WAVE_NT1.
/// @param N  unused.
/// @param K  unused.
/// @return   amd_buffer_coherence_enum::WAVE_NT1 when M <= 416, otherwise
///           amd_buffer_coherence_enum::coherence_default.
CK_TILE_HOST static constexpr amd_buffer_coherence_enum
GetBMemNTType(index_t M, index_t N, index_t K)
{
    using coherence = ck_tile::amd_buffer_coherence_enum;

    ck_tile::ignore = N;
    ck_tile::ignore = K;

    // NOTE(review): 416 is a hard-coded threshold; its origin is not shown
    // here (presumably tuned empirically) -- confirm before changing.
    return (M <= 416) ? coherence::WAVE_NT1 : coherence::coherence_default;
}
|
||||
|
||||
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
|
||||
@@ -16,10 +16,12 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
@@ -28,6 +30,8 @@ struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
BMemNTType_,
|
||||
BPreShufflePermute_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
@@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
// For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr auto BMemNTType = Problem::BMemNTType;
|
||||
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
|
||||
@@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
static constexpr auto BMemNTType = Problem::BMemNTType;
|
||||
static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -16,10 +16,12 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
ADataType_,
|
||||
CDataType_,
|
||||
@@ -28,6 +30,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
Scheduler_,
|
||||
HasHotLoop_,
|
||||
TailNum_,
|
||||
BMemNTType_,
|
||||
BPreShufflePermute_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
@@ -316,10 +316,12 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct FlatmmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
@@ -353,6 +355,9 @@ struct FlatmmPipelineProblem
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static constexpr auto BMemNTType = BMemNTType_;
|
||||
static constexpr bool BPreShufflePermute = BPreShufflePermute_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
Reference in New Issue
Block a user