diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index ba9201135c..8830adfdd9 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1413,10 +1413,10 @@ enum struct amd_buffer_coherence_enum WAVE_NT1 = 2, GROUP_NT0 = 1, GROUP_NT1 = 3, - DEVICE_NT0 = 8, - DEVICE_NT1 = 10, - SYSTEM_NT0 = 9, - SYSTEM_NT1 = 11, + DEVICE_NT0 = 16, + DEVICE_NT1 = 18, + SYSTEM_NT0 = 17, + SYSTEM_NT1 = 19, }; template ( - b_flat_ptr, - make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); + if constexpr(!FlatmmPipeline::BPreShufflePermute) + { + index_t kFlatK = + kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK) + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + index_t kFlatK = FlatmmPipeline::flatKPerWarp; + index_t kFlatN0 = (kargs.N >> 4); + index_t kFlatK0 = (kargs.K >> 7); + + auto b_tensor_view_naive = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatK0, kFlatN0 - kargs.n_padded_zeros / NPerXdl, kFlatK), + make_tuple(kFlatK * (kFlatN0 - kargs.n_padded_zeros / NPerXdl), kFlatK, 1), + number{}, + number<1>{}); + return transform_tensor_view( + b_tensor_view_naive, + make_tuple( + make_pass_through_transform(kFlatN0 - kargs.n_padded_zeros / NPerXdl), + make_merge_transform_v3_division_mod(make_tuple(kFlatK0, kFlatK))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } }(); // TODO: enable vector write for C in ColMajor diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 79b36adec4..e4f186dead 100644 --- 
a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -24,6 +24,18 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } + /* GetBMemNTType: returns amd_buffer_coherence_enum::WAVE_NT1 when M <= 416 and amd_buffer_coherence_enum::coherence_default otherwise. N and K are accepted but only assigned to ck_tile::ignore, so they do not influence the result. NOTE(review): 416 is an unexplained threshold, presumably tuned empirically for small-M problems; confirm and consider giving it a named constant. NOTE(review): judging by the name, the result is meant to feed a BMemNTType template argument; confirm against the call site. */ CK_TILE_HOST static constexpr amd_buffer_coherence_enum + GetBMemNTType(index_t M, index_t N, index_t K) + { + ck_tile::ignore = N; + ck_tile::ignore = K; + if(M <= 416) + { + return ck_tile::amd_buffer_coherence_enum::WAVE_NT1; + } + return ck_tile::amd_buffer_coherence_enum::coherence_default; + } + template CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 8ec23b7570..74d82b8949 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -16,10 +16,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct F16xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem { using BlockGemmShape = BlockGemmShape_; @@ -183,6 +187,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. 
static constexpr bool DoubleSmemBuffer = false; + static constexpr auto BMemNTType = Problem::BMemNTType; + static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute; + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { diff --git a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp index fe6d3ec830..5681726afe 100644 --- a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp @@ -115,6 +115,9 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1 static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + static constexpr auto BMemNTType = Problem::BMemNTType; + static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 87ae7f57d8..69e9441ae5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -16,10 +16,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct MXFlatmmPipelineProblem : FlatmmPipelineProblem { using BlockGemmShape = BlockGemmShape_; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index e35f4ce70d..46c1f69b12 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -316,10 +316,12 @@ template + GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, + bool HasHotLoop_ = true, + TailNumber TailNum_ = TailNumber::Full, + amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default, + bool BPreShufflePermute_ = false, + typename ComputeDataType_ = ADataType_> struct FlatmmPipelineProblem { using Traits = remove_cvref_t; @@ -353,6 +355,9 @@ struct FlatmmPipelineProblem static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; + static constexpr auto BMemNTType = BMemNTType_; + static constexpr bool BPreShufflePermute = BPreShufflePermute_; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off