From c6135f6abe3291693f28e76ac75f0c5e21750967 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Wed, 10 Sep 2025 05:03:08 -0500 Subject: [PATCH] updates some fixes. --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 + .../mixed_prec/mixed_prec_flatmm.cpp | 41 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 48 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp | 2 +- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 53 +- .../ops/epilogue/chuffle_epilogue_feat.hpp | 760 ------------------ .../ops/epilogue/cshuffle_epilogue.hpp | 226 +++--- .../kernel/mixed_prec_flatmm_kernel.hpp | 14 +- .../ops/flatmm/kernel/mx_flatmm_kernel.hpp | 19 +- ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 10 +- 10 files changed, 230 insertions(+), 945 deletions(-) delete mode 100644 include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index bd462028f3..64a2bcfbfe 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,5 +1,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp) +add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) @@ -13,3 +14,4 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo) #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1") target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index 22d8f34efd..3744f589f4 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -160,6 +160,16 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& << std::endl; } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + if(s.flush_cache_) { std::cout << "Flushing cache..." << std::endl; @@ -174,34 +184,25 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( + rotating_mem_ptr = std::make_unique>( kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + rotating_mem_ptr->Print(); - auto run_flush_cache = [&]() { - // flush icache + preprocess = [&]() { ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + rotating_mem_ptr->Next(); + clear_gemm_output(); }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + preprocess = clear_gemm_output; } - return ave_time; + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 4721c17b80..ff773ad2a2 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -154,12 +154,22 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" << "Shape: " << CodegenFlatmmShape::GetName() << "\n" << "problem: " << CodegenPipelineProblem::GetName() << "\n" - << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" + << "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n" << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + if(s.flush_cache_) { std::cout << "Flushing cache..." << std::endl; @@ -174,34 +184,25 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( + rotating_mem_ptr = std::make_unique>( kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + rotating_mem_ptr->Print(); - auto run_flush_cache = [&]() { - // flush icache + preprocess = [&]() { ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + rotating_mem_ptr->Next(); + clear_gemm_output(); }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); } else { - // ave_time = - // ck_tile::launch_kernel(s, - // ck_tile::make_kernel( - // Kernel{}, grids, blocks, 0, kargs)); + preprocess = clear_gemm_output; } - return ave_time; + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { @@ -265,7 +266,6 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, stride_B, {}, stride_C, - {}, scale_a, scale_b}; @@ -391,7 +391,7 @@ auto preShuffleScale(const ck_tile::HostTensor& scale) return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4}); } -#include "run_mx_prec_flatmm.inc" +#include "run_mx_flatmm.inc" template int run_mx_flatmm_example(int argc, char* argv[]) @@ -463,7 +463,7 @@ int main(int argc, char* argv[]) } else if(warp_tile == 1) { - thow std::runtime_error("Only support MFMA_16x16x128 now!"); + throw std::runtime_error("Only support MFMA_16x16x128 now!"); } else { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp index 7fb344f46a..b47d3a95ab 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -12,4 +12,4 @@ #include "ck_tile/ops/flatmm.hpp" #include "ck_tile/ops/gemm.hpp" -#include "mxfp4_flatmm.hpp" \ No newline at end of file +#include "mxfp4_flatmm.hpp" diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index 1edadb5b57..9bd2cc054e 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -19,12 +19,12 @@ int run_mx_flatmm_with_layouts(int argc, if(!result) return -1; - using ADataType = PrecActType; - using BDataType = PrecWeightType; - using CDataType = CDataType; + using ADataType = PrecActType; + using BDataType = PrecWeightType; + // using CDataType = CDataType; using AccDataType = float; - using ScaleType = ck_tile::e8m0_t; + using ScaleDataType = ck_tile::e8m0_t; constexpr int ScaleGranularityM = 1; constexpr int ScaleGranularityN = 1; @@ -52,11 +52,12 @@ int run_mx_flatmm_with_layouts(int argc, auto scale_stride_B = ck_tile::get_default_stride( K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout)); - if(K % DequantGranularityK != 0) + if(K % ScaleGranularityK != 0) { - thow std::runtime_error("wrong! K must be multiple of ScaleGranularityK."); + throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK."); } - if(K % ck_tile::packed_size_v != 0 || K % ck_tile::packed_size_v != 0) + if(K % ck_tile::numeric_traits::PackedSize != 0 || + K % ck_tile::numeric_traits::PackedSize != 0) { throw std::runtime_error("wrong! K must be multiple of packed size."); } @@ -68,31 +69,31 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::HostTensor c_rslt_host( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::HostTensor scale_a(ck_tile::host_tensor_descriptor( + ck_tile::HostTensor scale_a(ck_tile::host_tensor_descriptor( M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); - ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( + ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); if(init_method == 0) { ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); } else if(init_method == 1) { ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); - ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); } else if(init_method == 2) { ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); - ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); } #if 0 @@ -186,8 +187,8 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); preShuffleWeight(b_origin_host.begin(), b_shuffled_host.begin(), N, K); - ck_tile::HostTensor scale_a_shuffled = preShuffleScale(scale_a); - ck_tile::HostTensor scale_b_shuffled = preShuffleScale(scale_b); + ck_tile::HostTensor scale_a_shuffled = preShuffleScale(scale_a); + ck_tile::HostTensor scale_b_shuffled = preShuffleScale(scale_b); ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes()); @@ -197,15 +198,15 @@ int run_mx_flatmm_with_layouts(int argc, ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); a_dev_buf.ToDevice(a_host.data()); - b_shuffle_dev_buf.ToDevice(b_shuffled_host.data()); + b_shuffled_dev_buf.ToDevice(b_shuffled_host.data()); c_rslt_host.SetZero(); scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); - auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer{ - static_cast(scale_a_dev_buf.GetDeviceBuffer()), M / DequantGranularityM}; - auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer{ - static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN}; + auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM}; + auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; invoke_mx_flatmm(a_dev_buf, - b_shuffle_dev_buf, + b_shuffled_dev_buf, c_dev_buf, M, N, @@ -244,11 +245,7 @@ int run_mx_flatmm_with_layouts(int argc, c_m_n_host_ref.SetZero(); ck_tile::reference_mx_gemm( - a_host.data(), - b_origin_host.data(), - c_m_n_host_ref.data(), - scale_a.data(), - scale_b.data()); + a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); diff --git a/include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp b/include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp deleted file mode 100644 index 3c9a0d7a8b..0000000000 --- a/include/ck_tile/ops/epilogue/chuffle_epilogue_feat.hpp +++ /dev/null @@ -1,760 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" - -namespace ck_tile { - -template // The number of continuous xdl_output per warp -struct CShuffleEpilogueProblem -{ - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr bool FixedVectorSize = FixedVectorSize_; - static constexpr index_t VectorSizeC = VectorSizeC_; - static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; - static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; - static constexpr index_t kNumWaveGroups = kNumWaveGroups_; - static constexpr index_t NumDTensor = DsDataType::size(); - - static_assert(NumDTensor == DsLayout::size(), - "The size of DsDataType and DsLayout should be the same"); -}; - -template -struct CShuffleEpilogue -{ - using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t MWave = Problem::MWave; - static constexpr index_t NWave = Problem::NWave; - static constexpr index_t MPerXdl = Problem::MPerXdl; - static constexpr index_t NPerXdl = Problem::NPerXdl; - static constexpr index_t KPerXdl = Problem::KPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr bool FixedVectorSize = Problem::FixedVectorSize; - static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; - static constexpr index_t MPerIteration = MPerXdl * MWave; - static constexpr index_t NPerIteration = NPerXdl * NWave; - static constexpr index_t NumDTensor = Problem::NumDTensor; - static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); - static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); - - static_assert(NumDTensor == DsLayout::size(), - "The size of DsDataType and DsLayout should be the same"); - /** - * @brief Get the vector store size for C tensor. - * - * @note The vector store size for output C tensor would depend on multiple factors - * like its data layout and warp gemm C transposition. In general it would - * be the number of consecutive elements in contiguous C dimension hold by - * single thread. - * - * @return The vector store size for C tensor. - */ - CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() - { - if constexpr(FixedVectorSize) - { - return VectorSizeC; - } - constexpr index_t max_vector_size = 16; - if constexpr(std::is_same_v) - { - return std::min(static_cast(NPerIteration), - static_cast(max_vector_size / sizeof(ODataType))); - } - else if constexpr(std::is_same_v) - { - return std::min(static_cast(MPerIteration), - static_cast(max_vector_size / sizeof(ODataType))); - } - else - { - static_assert(false, "Unsupported ELayout!"); - } - } - - /** - * @brief Get the vector store size for Di tensor. - * - * @return The vector store size for Di tensor. - */ - template - CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) - { - constexpr index_t max_vector_size = 16; - using DiDataType = remove_cvref_t>; - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return std::min(static_cast(NPerIteration), - static_cast(max_vector_size / sizeof(DiDataType))); - } - else if constexpr(std::is_same_v) - { - return std::min(static_cast(MPerIteration), - static_cast(max_vector_size / sizeof(DiDataType))); - } - else - { - static_assert(false, "Unsupported DLayout!"); - } - return max_vector_size / sizeof(DiDataType); - } - /** - * @brief Shuffle tile configuration parameters - * - * @details These parameters control the number of XDL tiles processed per wave in each shuffle - * iteration: - * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave - * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave - */ - static constexpr auto shuffle_tile_tuple = [] { - constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size(); - if constexpr(elem_per_thread >= GetVectorSizeC()) - { - return std::make_tuple(1, 1); - } - else - { - constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; - if constexpr(std::is_same_v) - { - static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && - (kMPerBlock % num_xdl_shuffles == 0), - "kMPerBlock must be divisible by MPerXdl*MWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1); - } - else - { - static_assert((kNPerBlock % (NPerXdl * NWave) == 0) && - (kNPerBlock % num_xdl_shuffles == 0), - "kNPerBlock must be divisible by NPerXdl*NWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave))); - } - } - }(); - static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); - static constexpr index_t NumNXdlPerWavePerShuffle = - max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple)); - - static_assert(NumNXdlPerWavePerShuffle % BlockedXDLN_PerWarp == 0); - - static constexpr auto MNPerIterationShuffle = [] { - constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; - constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle; - if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0) - return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave); - else - return std::make_tuple(m_val, n_val); - }(); - static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); - static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); - - using WG = WarpGemmMfmaDispatcher; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() - { - // N is contiguous dimension - if constexpr(std::is_same_v) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{})); - } - // M is contiguous dimension - else if constexpr(std::is_same_v) - { - return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number<1>{}, number{})); - } - else - { - static_assert(false, "Unsupported ELayout!"); - } - } - - CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() - { - constexpr auto block_outer_dstr_encoding = [] { - if constexpr(BlockedXDLN_PerWarp == 1) - { - return tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - } - else - { - constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; - // BlockedLayout - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; - } - }(); - constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); - - return block_dstr_encoding; - } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); - } - - template = 0> - CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* p_smem) - { - constexpr int kM0 = MWave; - constexpr int kM2 = 4; - constexpr int kM1 = MPerXdl / kM2; - - constexpr int kN0 = NWave; - constexpr int kN1 = NPerXdl; - constexpr int kN2 = NRepeat; - - using IntrThreadShuffleEncode = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>; - static_assert(GetVectorSizeC() % kN2 == 0); - - constexpr auto dram_tile_distribution = - make_static_tile_distribution(IntrThreadShuffleEncode{}); - - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - auto shuffle_acc = make_static_distributed_tensor(dram_tile_distribution); - auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - - static_for<0, MRepeat, 1>{}([&](auto mIter) { - shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - - static_for<0, NRepeat, 1>{}([&](auto n_idx) { - // transpose thread matrix - c_out_tensor.get_thread_buffer()[n_idx + 0 * NRepeat] = type_convert( - shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0]); - c_out_tensor.get_thread_buffer()[n_idx + 1 * NRepeat] = type_convert( - shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1]); - c_out_tensor.get_thread_buffer()[n_idx + 2 * NRepeat] = type_convert( - shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2]); - c_out_tensor.get_thread_buffer()[n_idx + 3 * NRepeat] = type_convert( - shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3]); - }); - - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - move_tile_window(out_dram_window, {number{}, number<0>{}}); - - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], {number{}, number<0>{}}); - }); - }); - } - - template = 0> - CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* p_smem) - { - constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - - auto lds_tile = make_static_distributed_tensor(LdsTileDistr); - - constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); - auto o_lds_block = make_tensor_view( - static_cast(p_smem), lds_block_desc); - - auto in_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - LdsTileDistr); - - auto out_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - - using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; - constexpr index_t num_access = SFC::get_num_of_access(); - - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); - - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); - - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - static_for<0, num_access, 1>{}([&](auto iAccess) { - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(iAccess); - - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - - lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - - store_tile(in_lds_window, c_warptile_in_tensor_casted); - block_sync_lds(); - - auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - - const auto ds_tensor = generate_tuple( - [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); - - const auto c_ds_tiles = concat_tuple_of_reference( - tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); - - tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); - - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - if constexpr(iAccess != num_access - 1) - { - constexpr auto step = SFC::get_forward_step(iAccess); - - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); - - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); - }); - } - }); - } - - template = 0> - CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* p_smem, - ScaleMWindow scale_m_window, - ScaleNWindow scale_n_window) - { - constexpr int kM0 = MWave; - constexpr int kM2 = 4; - constexpr int kM1 = MPerXdl / kM2; - static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now"); - - constexpr int kN0 = NWave; - constexpr int kN1 = NPerXdl; - constexpr int kN2 = NRepeat; - - using IntrThreadShuffleEncode = - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>; - static_assert(GetVectorSizeC() % kN2 == 0); - - constexpr auto dram_tile_distribution = - make_static_tile_distribution(IntrThreadShuffleEncode{}); - - constexpr int DynamicTileOffsetFlag = 0; - - auto permute_scale_n_view_1 = transform_tensor_view( - scale_n_window.get_bottom_tensor_view(), - make_tuple(make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, - number{}, - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{})); - auto permute_scale_n_view = transform_tensor_view( - permute_scale_n_view_1, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple(number{}, - number{}, - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - auto scale_m_window_with_dist = make_tile_window( - scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view, - scale_n_window.get_window_lengths(), - scale_n_window.get_window_origin(), - o_acc_tile.get_tile_distribution()); - - auto scale_m_buffer = load_tile(scale_m_window_with_dist); - auto scale_n_buffer = load_tile(scale_n_window_with_dist); - - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - using ShuffleAcc = - decltype(make_static_distributed_tensor(dram_tile_distribution)); - ShuffleAcc shuffle_acc[MRepeat]; - auto c_out_tensor_fp32 = - make_static_distributed_tensor(dram_tile_distribution); - auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - - constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product(); - - static_for<0, MRepeat, 1>{}([&](auto mIter) { - shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - - static_for<0, NumAccPerEpiTile, 1>{}( - [&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; }); - }); - - static_for<0, MRepeat, 1>{}([&](auto mIter) { - auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - - static_for<0, NRepeat, 1>{}([&](auto n_idx) { - // transpose thread matrix - c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] = - shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 0]; - c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] = - shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 1]; - c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] = - shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 2]; - c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] = - shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 3]; - }); - - c_out_tensor = cast_tile(c_out_tensor_fp32); - - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - move_tile_window(out_dram_window, {number{}, number<0>{}}); - - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], {number{}, number<0>{}}); - }); - }); - } - - template = 0> - CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, - const OAccTile& o_acc_tile, - const DsDramWindows& ds_dram_windows, - void* p_smem, - ScaleMWindow scale_m_window, - ScaleNWindow scale_n_window) - { - constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - - using LDSTileTensor = decltype(make_static_distributed_tensor(LdsTileDistr)); - LDSTileTensor lds_tile[2]; - - constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); - auto o_lds_block = make_tensor_view( - static_cast(p_smem), lds_block_desc); - - auto in_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - LdsTileDistr); - - auto out_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - - using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; - constexpr index_t num_access = SFC::get_num_of_access(); - - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); - - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); - - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - auto scale_m_window_with_dist = make_tile_window( - scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - auto scale_n_window_with_dist = make_tile_window( - scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - - auto scale_m_buffer = load_tile(scale_m_window_with_dist); - auto scale_n_buffer = load_tile(scale_n_window_with_dist); - - constexpr int NumAccPerEpiTile = - NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product(); - auto epi_tile_idx_slice = - [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) { - return acc_tile_like_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - }; - - lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{}); - - auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{}); - auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{}); - static_for<0, NumAccPerEpiTile, 1>{}( - [&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; }); - - static_for<0, num_access, 1>{}([&](auto iAccess) { - constexpr int read_stage = iAccess % 2; - constexpr int write_stage = read_stage ^ 1; - - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(number{}); - - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile[read_stage]); - - store_tile(in_lds_window, c_warptile_in_tensor_casted); - - if constexpr(iAccess < num_access - 1) - { - lds_tile[write_stage].get_thread_buffer() = - epi_tile_idx_slice(o_acc_tile, mIter, nIter); - - epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter); - epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter); - - static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) { - lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; - }); - } - - block_sync_lds(); - - auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - - const auto ds_tensor = generate_tuple( - [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); - - const auto c_ds_tiles = concat_tuple_of_reference( - tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); - - tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); - - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - if constexpr(iAccess != num_access - 1) - { - constexpr auto step = SFC::get_forward_step(iAccess); - - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); - - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); - }); - } - }); - } -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index d605be6af2..30c0c20d0e 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -414,7 +414,8 @@ struct CShuffleEpilogue GetVectorSizeC(), tile_distribution_pattern::thread_raked, Problem::kNumWaveGroups>; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); auto d_dram_windows = generate_tuple( [&](auto idx) { @@ -482,16 +483,16 @@ struct CShuffleEpilogue template = 0> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem, - ScaleMWindow scale_m_window, - ScaleNWindow scale_n_window) + ScaleM scale_m, + ScaleN scale_n) { constexpr int kM0 = MWave; constexpr int kM2 = 4; @@ -509,43 +510,9 @@ struct CShuffleEpilogue tuple, sequence<1, 1>>, sequence<1, 2>, sequence<2, 2>>; - static_assert(GetVectorSizeC() % kN2 == 0); - constexpr auto dram_tile_distribution = make_static_tile_distribution(IntrThreadShuffleEncode{}); - constexpr int DynamicTileOffsetFlag = 0; - - auto permute_scale_n_view_1 = transform_tensor_view( - scale_n_window.get_bottom_tensor_view(), - make_tuple(make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, - number{}, - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2, 3, 4>{})); - auto permute_scale_n_view = transform_tensor_view( - permute_scale_n_view_1, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod(make_tuple(number{}, - number{}, - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 4, 2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - auto scale_m_window_with_dist = make_tile_window( - scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - auto scale_n_window_with_dist = make_tile_window(permute_scale_n_view, - scale_n_window.get_window_lengths(), - scale_n_window.get_window_origin(), - o_acc_tile.get_tile_distribution()); - - auto scale_m_buffer = load_tile(scale_m_window_with_dist); - auto scale_n_buffer = load_tile(scale_n_window_with_dist); - auto d_dram_windows = generate_tuple( [&](auto idx) { return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); @@ -563,39 +530,56 @@ struct CShuffleEpilogue make_static_distributed_tensor(dram_tile_distribution); auto c_out_tensor = make_static_distributed_tensor(dram_tile_distribution); - constexpr int NumAccPerEpiTile = NRepeat * c_warp_y_lengths.product(); + const index_t iMWarp = get_warp_id() / NWave; + const index_t iNWarp = get_warp_id() - iMWarp * NWave; + const index_t iMLane = get_lane_id() / NPerXdl; + const index_t iNLane = get_lane_id() % NPerXdl; + + float vec_scale_A[kM2 * MRepeat]; + float vec_scale_B[NRepeat]; + + _Pragma("unroll") for(int i = 0; i < NRepeat; ++i) + { + vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl]; + } + _Pragma("unroll") for(int i = 0; i < MRepeat; ++i) + { + vec_scale_A[i * kM2 + 0] = + scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; + vec_scale_A[i * kM2 + 1] = + scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; + vec_scale_A[i * kM2 + 2] = + scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; + vec_scale_A[i * kM2 + 3] = + scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; + } static_for<0, MRepeat, 1>{}([&](auto mIter) { shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - auto epi_scale_n = scale_n_buffer.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - - static_for<0, NumAccPerEpiTile, 1>{}( - [&](auto i) { shuffle_acc[mIter].get_thread_buffer()[i] *= epi_scale_n[i]; }); + static_for<0, NRepeat, 1>{}([&](auto n_idx) { + shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx]; + shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx]; + shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx]; + shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx]; + }); }); static_for<0, MRepeat, 1>{}([&](auto mIter) { - auto epi_scale_m = scale_m_buffer.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - static_for<0, NRepeat, 1>{}([&](auto n_idx) { - // transpose thread matrix c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 0]; + vec_scale_A[mIter * kM2 + 0]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 1]; + vec_scale_A[mIter * kM2 + 1]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 2]; + vec_scale_A[mIter * kM2 + 2]; c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] = shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] * - epi_scale_m[n_idx * c_warp_y_lengths.product() + 3]; + vec_scale_A[mIter * kM2 + 3]; }); c_out_tensor = cast_tile(c_out_tensor_fp32); @@ -619,16 +603,16 @@ struct CShuffleEpilogue template = 0> CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem, - ScaleMWindow scale_m_window, - ScaleNWindow scale_n_window) + ScaleM scale_m, + ScaleN scale_n) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); @@ -650,6 +634,10 @@ struct CShuffleEpilogue make_tuple(number{}, number{}), {0, 0}); + // using SFC = space_filling_curve, + // sequence<0, 1>, + // sequence>; constexpr index_t num_access = SFC::get_num_of_access(); static_assert(std::is_same_v, @@ -662,7 +650,8 @@ struct CShuffleEpilogue GetVectorSizeC(), tile_distribution_pattern::thread_raked, Problem::kNumWaveGroups>; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); auto d_dram_windows = generate_tuple( [&](auto idx) { @@ -674,32 +663,63 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - auto scale_m_window_with_dist = make_tile_window( - scale_m_window, scale_m_window.get_window_origin(), o_acc_tile.get_tile_distribution()); - auto scale_n_window_with_dist = make_tile_window( - scale_n_window, scale_n_window.get_window_origin(), o_acc_tile.get_tile_distribution()); + constexpr int kM2 = 4; // Val + constexpr int kM1 = (64 / NPerXdl); // Thr + constexpr int kM0 = MPerXdl / kM1 / kM2; // Val - auto scale_m_buffer = load_tile(scale_m_window_with_dist); - auto scale_n_buffer = load_tile(scale_n_window_with_dist); + const index_t iMWarp = get_warp_id() / NWave; + const index_t iNWarp = get_warp_id() - iMWarp * NWave; + const index_t iMLane = get_lane_id() / NPerXdl; + const index_t iNLane = get_lane_id() % NPerXdl; - constexpr int NumAccPerEpiTile = - NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle * c_warp_y_lengths.product(); - auto epi_tile_idx_slice = - [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) { - return acc_tile_like_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - }; + float vec_scale_A[kM0 * kM2 * MRepeat]; + float vec_scale_B[NRepeat]; - lds_tile[0].get_thread_buffer() = epi_tile_idx_slice(o_acc_tile, number<0>{}, number<0>{}); + _Pragma("unroll") for(int i = 0; i < NRepeat; ++i) + { + vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane]; + } + _Pragma("unroll") for(int i = 0; i < MRepeat; ++i) + { + _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) + { + vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] = + scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + + i * MPerXdl * MWave]; + vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] = + scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + + i * MPerXdl * MWave]; + vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] = + scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + + i * MPerXdl * MWave]; + vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] = + scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl + + i * MPerXdl * MWave]; + } + } - auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, number<0>{}, number<0>{}); - auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, number<0>{}, number<0>{}); - static_for<0, NumAccPerEpiTile, 1>{}( - [&](auto i) { lds_tile[0].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; }); + lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); + static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { + static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { + constexpr int acc_xdl_offset = + (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); + _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) + { + lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *= + vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl]; + lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *= + vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl]; + lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *= + vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl]; + lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *= + vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl]; + } + }); + }); static_for<0, num_access, 1>{}([&](auto iAccess) { constexpr int read_stage = iAccess % 2; @@ -717,14 +737,40 @@ struct CShuffleEpilogue if constexpr(iAccess < num_access - 1) { - lds_tile[write_stage].get_thread_buffer() = - epi_tile_idx_slice(o_acc_tile, mIter, nIter); - - epi_scale_m = epi_tile_idx_slice(scale_m_buffer, mIter, nIter); - epi_scale_n = epi_tile_idx_slice(scale_n_buffer, mIter, nIter); - - static_for<0, NumAccPerEpiTile, 1>{}([&](auto i) { - lds_tile[write_stage].get_thread_buffer()[i] *= epi_scale_m[i] * epi_scale_n[i]; + lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); + static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { + static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { + constexpr int acc_xdl_offset = + (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); + _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0) + { + lds_tile[write_stage] + .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *= + vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + + m_xdl * kM0 * kM2 + m0 * kM2 + 0] * + vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; + lds_tile[write_stage] + .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *= + vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + + m_xdl * kM0 * kM2 + m0 * kM2 + 1] * + vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; + lds_tile[write_stage] + .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *= + vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + + m_xdl * kM0 * kM2 + m0 * kM2 + 2] * + vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; + lds_tile[write_stage] + .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *= + vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 + + m_xdl * kM0 * kM2 + m0 * kM2 + 3] * + vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; + } + }); }); } diff --git a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp index eb2d27443a..6e8d2d5337 100644 --- a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp @@ -22,13 +22,13 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize; static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; using ADataType = remove_cvref_t; diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 346c82a129..b2a5f39793 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -22,13 +22,13 @@ struct MXFlatmmKernel : FlatmmKernel; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize; static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; using ADataType = remove_cvref_t; @@ -36,9 +36,8 @@ struct MXFlatmmKernel : FlatmmKernel; - using BlockGemm = remove_cvref_t; - static constexpr int MThreadPerXdl = BlockGemm::WarpGemm::kM; - static constexpr int NThreadPerXdl = BlockGemm::WarpGemm::kN; + static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); static constexpr int KThreadPerXdl = 64 / MThreadPerXdl; static constexpr int APackedSize = numeric_traits::PackedSize; diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index f89f2d3d57..aa3e0493b1 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -39,10 +39,10 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem @@ -51,7 +51,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1; using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape