From fb76450e6375d7c6b761ce0af0462c952dc46f5b Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 16 Jul 2025 10:12:19 +0000 Subject: [PATCH 01/55] merge from dteng_flatmm_opt --- example/ck_tile/18_flatmm/CMakeLists.txt | 11 +- example/ck_tile/18_flatmm/flatmm_basic.hpp | 117 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 6 + .../core/arch/amd_buffer_addressing.hpp | 8 +- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 1 + .../ops/flatmm/kernel/flatmm_kernel.hpp | 26 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1249 +++++++++++++---- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 31 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 142 ++ 9 files changed, 1324 insertions(+), 267 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 6d6b71ea18..87237458c5 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -2,5 +2,14 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32_F8=1 -Wno-unused-local-typedef) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1") target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 963a6ba675..138696eace 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -164,6 +164,120 @@ struct is_8bit_type { }; +template +struct GemmConfig +{ +#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16 + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16 + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 8; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune) + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +#else + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +#endif +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -177,7 +291,8 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b583612cfb..b5957a7c53 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -254,6 +254,12 @@ int run_flatmm_example_with_layouts(int argc, c_dev_buf.FromDevice(c_rslt_host.data()); bool pass = true; + // for(int i=0;i(c_rslt_host.mData.size());i++){ + // printf("dteng---a[%d][%d]=%f\n",i/256,i%256,ck_tile::type_convert(a_host.mData[i])); + // } + // for(int i=0;i(c_rslt_host.mData.size());i++){ + // printf("dteng---c[%d][%d]=%f\n",i/256,i%256,ck_tile::type_convert(c_rslt_host.mData[i])); + // } if(arg_parser.get_int("v") == 1) { diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index aafc6c0a85..ca5ed41dd4 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz { buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); + // r.x = __builtin_amdgcn_readfirstlane(r.x); + // r.y = __builtin_amdgcn_readfirstlane(r.y); + // r.z = __builtin_amdgcn_readfirstlane(r.z); + // r.w = __builtin_amdgcn_readfirstlane(r.w); return r; } diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 18b2fe6483..85494b3a76 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -113,6 +113,7 @@ struct BlockFlatmmASmemBSmemCRegV1 merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); }); }); diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 76df056ea6..607645c097 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -146,10 +146,14 @@ struct FlatmmKernel hostArgs.k_batch}; } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize() { return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize() + { + return FlatmmPipeline::GetSmemSize(); + } struct SplitKBatchOffset { @@ -560,7 +564,8 @@ struct FlatmmKernel const BDataType* b_flat_ptr, const std::array& ds_ptr, EDataType* e_ptr, - void* smem_ptr, + void* smem_ptr_ping, + void* smem_ptr_pong, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, @@ -580,15 +585,16 @@ struct FlatmmKernel const auto& b_flat_block_window = gemm_tile_windows.at(I1); const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = FlatmmPipeline{}.template operator()( - a_block_window, b_flat_block_window, num_loop, smem_ptr); + a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); + + // Run Epilogue Pipeline + if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_ping); } } @@ -607,7 +613,8 @@ struct FlatmmKernel EDataType* e_ptr = static_cast(kargs.e_ptr); // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + __shared__ char smem_ptr_ping[GetSmemPingSize()]; + __shared__ char smem_ptr_pong[GetSmemPongSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && @@ -618,7 +625,8 @@ struct FlatmmKernel b_flat_ptr, kargs.ds_ptr, e_ptr, - smem_ptr, + smem_ptr_ping, + smem_ptr_pong, kargs, splitk_batch_offset, i_m, diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index edb5853c7f..7e239d00a4 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -3,6 +3,7 @@ #pragma once +// #define FINEGRADE_LOADSTORE #include "ck_tile/core.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" @@ -12,26 +13,24 @@ namespace ck_tile { template struct BaseFlatmmPipelineAGmemBGmemCRegV1 { - static constexpr index_t PrefetchStages = 1; - static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 1; - static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + static constexpr index_t PrefetchStages = 2; - CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - - CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } - - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) { - return TailNumber::Empty; + return num_loop > PrefetchStages; } + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } template CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) { return run_func(bool_constant{}, integral_constant{}); } }; + template struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1 { @@ -47,6 +46,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV using BlockFlatmm = remove_cvref_t())>; + + static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -73,17 +76,70 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t kLdsAlignmentInBytes = 16; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; - using BlockTile = remove_cvref_t; - using BlockWarps = remove_cvref_t; - using WarpTile = remove_cvref_t; + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; - using Base::UsePersistentKernel; + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; + static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; + static constexpr index_t BloadGap = MIterPerWarp / 2; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + + /* + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1 + defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1 + defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1 + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 + */ + + #if (defined(USING_MFMA_16x16x32_F8) || \ + defined(USING_MFMA_32x32x16_F8) || \ + defined(USING_MFMA_16x16x16_F16) || \ + defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5 + static constexpr auto mfma_per_wg = 2; + static constexpr auto dsread_per_wg = 1; + #elif (defined(USING_MFMA_16x16x32_F16) || \ + defined(USING_MFMA_32x32x16_F16) || \ + defined(USING_MFMA_16x16x128_F4) || \ + defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1 + static constexpr auto mfma_per_wg = 1; + static constexpr auto dsread_per_wg = 1; + #elif (defined(USING_MFMA_16x16x128_F8) || \ + defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2 + static constexpr auto mfma_per_wg = 1; + static constexpr auto dsread_per_wg = 2; + #endif [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -104,83 +160,369 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { - constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 17 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 18 - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - + // 0 M3N3: 48 26 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 28 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 30 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - - constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; - constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; - constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; - - if constexpr(WG::kM == 16 && WG::kN == 16) - { - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + #if 0 // MI350 FP8 16X16 128*256*256 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA - }); - } - else if constexpr(WG::kM == 32 && WG::kN == 32 && - (A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) - { - static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 // MI350 FP8 16X16 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA - } + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 // MI300 FP8 16X16 128*128*128 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 // MI300 FP8 16X16 128*256*128 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 //MI300 FP8 16X16 16*64*256 + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + #endif + } + + + CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() + { + #if 0 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_barrier(0); + #endif } template @@ -188,7 +530,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, - void* p_smem) const + void* p_smem_ping, + void* p_smem_pong) const { static_assert( std::is_same_v> && @@ -197,7 +540,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV "([A|B]DataType) defined in Problem definition!"); constexpr bool is_a_col_major = std::is_same_v; - static_assert(is_a_col_major ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) @@ -205,69 +547,137 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), "A block window has incorrect lengths for defined ALayout!"); - constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; - constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; - - constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; - constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; const index_t iMWarp = get_warp_id() / NWarp; + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + __builtin_amdgcn_sched_barrier(0); + // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load + #ifndef FINEGRADE_LOADSTORE auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeADramTileDistribution()); - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + #else + auto a_copy_dram_window_tmp = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramDistribution()); + + statically_indexed_array a_copy_dram_window; + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + a_copy_dram_window(AIter) = a_copy_dram_window_tmp; + move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0}); + }); + + auto a_copy_lds_window_ping_tmp = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution() + ); + + statically_indexed_array a_copy_lds_window_ping; + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp; + move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0}); + }); + + auto a_copy_lds_window_pong_tmp = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution() + ); + + statically_indexed_array a_copy_lds_window_pong; + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp; + move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0}); + }); + #endif // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // auto a_lds_gemm_window = make_tile_window( + // a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}); - auto a_warp_window_tmp = make_tile_window( - a_lds_gemm_window.get_bottom_tensor_view(), + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = make_tile_window( + a_lds_block_ping, make_tuple(number{}, number{}), - a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); statically_indexed_array< - statically_indexed_array, + statically_indexed_array, MIterPerWarp> - a_warp_windows; + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - move_tile_window(a_warp_windows(mIter)(kIter), + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); // Block GEMM auto block_flatmm = BlockFlatmm(); + // Acc register tile + auto c_block_tile = block_flatmm.MakeCBlockTile(); // B flat DRAM window for load auto b_flat_distribution = @@ -279,13 +689,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV b_flat_dram_block_window_tmp.get_window_origin(), b_flat_distribution); - // Acc register tile - auto c_block_tile = block_flatmm.MakeCBlockTile(); - - // prefetch - // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - + // pingpong buffer for B statically_indexed_array< statically_indexed_array, NIterPerWarp> @@ -294,158 +698,497 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV statically_indexed_array< statically_indexed_array, NIterPerWarp> - b_warp_tensor; + b_warp_tensor_ping; statically_indexed_array< statically_indexed_array, NIterPerWarp> - b_warp_tensor_2; + b_warp_tensor_pong; + + // Prefetch A0 + #ifndef FINEGRADE_LOADSTORE + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + #else + statically_indexed_array{}))), ACopyLoadNum> a_block_tile; + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); + move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); + }); + #endif + + // prefetch B static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - { - // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + // Prefill A0 + // if constexpr(std::is_same_v) + // { + // auto a_shuffle_tmp = make_static_distributed_tensor( + // PipelinePolicy::template MakeShuffledARegBlockDistribution()); + // shuffle_tile(a_shuffle_tmp, a_block_tile); + // const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + // store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + // } + // else + // { + // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); + // } + #ifndef FINEGRADE_LOADSTORE + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + #else + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + store_tile(a_copy_lds_window_ping(AIter), tile_elementwise_in(a_element_func, a_block_tile(AIter))); + }); + #endif + __builtin_amdgcn_sched_barrier(0); - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // Prefetch A1 + #ifndef FINEGRADE_LOADSTORE + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + #else + static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { + a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); + move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); + }); + #endif - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - // LDS write 0 - if constexpr(std::is_same_v) + block_sync_lds(); + + // preload A00,A10 from lds + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1; + statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_pong; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + // if(threadIdx.x==0){ + // for(int i=0;i(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); + // } + // } + // for(int i=0;i{}).get_thread_buffer_size();i++) { + // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", threadIdx.x, type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); + // } + + + index_t iCounter = (num_loop - 1) / 2; + // if constexpr(HasMainLoop) + // { + while(iCounter > 0) { - auto a_shuffle_tmp = make_static_distributed_tensor( - PipelinePolicy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); + #ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + #endif + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + } + + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); + }); + HotLoopScheduler(); + + //Next K + + #ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + #endif + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_ping(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_ping(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + } + + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + }); + HotLoopScheduler(); + + iCounter--; } - else + + // tail + if constexpr(TailNum == TailNumber::Even) { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + // __builtin_amdgcn_sched_barrier(0); + #ifndef FINEGRADE_LOADSTORE + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + #endif + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + } + + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + + TailHotLoopScheduler(); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); + }); + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + } + }); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + + // TailHotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); } - block_sync_lds(); - } - - index_t iCounter = num_loop / 2 - 1; - while(iCounter > 0) - { - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + } - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); }); - }); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // LDS write i + 1 - auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - HotLoopScheduler(); - block_sync_lds(); - - // iCounter--; - - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // LDS write i + 1 - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - - HotLoopScheduler(); - block_sync_lds(); - - iCounter--; - } - - // tail - { - // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - - // GEMM i - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); - - block_sync_lds(); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // move to i + 2 - // move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - - // move to next flat K - // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - HotLoopScheduler(); - block_sync_lds(); - - // GEMM num_loop - 1 - block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); - } + } + // } return c_block_tile; } @@ -454,14 +1197,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, - void* p_smem) const + void* p_smem_ping, + void* p_smem_pong) const { return operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_flat_dram_block_window_tmp, num_loop, - p_smem); + p_smem_ping, + p_smem_pong); } }; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 837aeb13e3..9e1daa6f16 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -340,6 +340,37 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + // constexpr index_t M0 = MPerBlock / (M2 * M1); + // static_assert(M0 * M1 * M2 == MPerBlock, + // "Incorrect M0, M2, M1 configuration! " + // "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<2>, + sequence<1>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c19d42ce25..449f567eb1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -178,6 +178,148 @@ using GemmPipelineProblem = GemmPipelineProblemBase; +template +struct FlatmmPipelineProblem +{ + using Traits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool TransposeC = Traits::TransposeC; + + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + + static constexpr bool kPadM = Traits::kPadM; + static constexpr bool kPadN = Traits::kPadN; + static constexpr bool kPadK = Traits::kPadK; + + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; + + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() + { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType) + ? pixels_per_thread + : PackedSize * VectorLoadSize / sizeof(ADataType); + } + else + { + return VectorLoadSize / sizeof(ADataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() + { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType) + ? pixels_per_thread + : PackedSize * VectorLoadSize / sizeof(BDataType); + } + else + { + return PackedSize * VectorLoadSize / sizeof(BDataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC() + { + if constexpr(std::is_same_v) + { + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size()); + constexpr index_t M0 = get_warp_size() / N2; + constexpr index_t M1 = BlockGemmShape::kM / M0; + + return std::min(M1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + else + { + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = BlockGemmShape::kN / N0; + + return std::min(N1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + } + + static constexpr index_t VectorSizeA = []() { + if constexpr(std::is_same_v) + { + return kPadK ? 1 : GetAlignmentA(); + } + else + { + return kPadM ? 1 : GetAlignmentA(); + } + }(); + + static constexpr index_t VectorSizeB = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentB(); + } + else + { + return kPadK ? 1 : GetAlignmentB(); + } + }(); + static constexpr index_t VectorSizeC = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentC(); + } + else + { + return kPadM ? 1 : GetAlignmentC(); + } + }(); +}; + template Date: Thu, 17 Jul 2025 08:40:35 +0000 Subject: [PATCH 02/55] fix tail handler bug --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 2 +- example/ck_tile/18_flatmm/flatmm_basic.hpp | 3 +++ .../pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 12 +++++++++++- .../ops/gemm/pipeline/gemm_pipeline_problem.hpp | 6 ++++-- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 4d29b68694..c93d708910 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -80,7 +80,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem struct FlatmmConfig16_950 : public FlatmmConfig16 { + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType); static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128; + static constexpr int kBlockPerCu = 1; }; template diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 7e239d00a4..fddf92bfc7 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -25,9 +25,19 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) { + if (TailNumber::Even == tail_num) + { + return run_func(bool_constant{}, integral_constant{}); + } + else if (TailNumber::Odd == tail_num) + { + return run_func(bool_constant{}, integral_constant{}); + } + // assert(false); return run_func(bool_constant{}, integral_constant{}); + // return run_func(bool_constant{}, integral_constant{}); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 449f567eb1..d3e0c06abe 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -178,11 +178,13 @@ using GemmPipelineProblem = GemmPipelineProblemBase; + template @@ -202,7 +204,7 @@ struct FlatmmPipelineProblem using CLayout = remove_cvref_t; static constexpr bool TransposeC = Traits::TransposeC; - + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); @@ -318,7 +320,7 @@ struct FlatmmPipelineProblem return kPadM ? 1 : GetAlignmentC(); } }(); -}; +}; template Date: Tue, 22 Jul 2025 08:09:35 +0000 Subject: [PATCH 03/55] adaptive scheduler instead of Macro definition --- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 535 ++++++++++-------- 1 file changed, 298 insertions(+), 237 deletions(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index fddf92bfc7..1a2348810e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -118,6 +118,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; + static constexpr auto warp_m = WarpTile::at(idxM); + static constexpr auto warp_n = WarpTile::at(idxN); + static constexpr auto warp_k = WarpTile::at(idxK); /* defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1 defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1 @@ -132,24 +135,74 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1 defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 */ + struct MfmaConfig + { + int mfma_per_wg; + int dsread_per_wg; + }; + static constexpr MfmaConfig GetMfmaConfig() + { - #if (defined(USING_MFMA_16x16x32_F8) || \ - defined(USING_MFMA_32x32x16_F8) || \ - defined(USING_MFMA_16x16x16_F16) || \ - defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5 - static constexpr auto mfma_per_wg = 2; - static constexpr auto dsread_per_wg = 1; - #elif (defined(USING_MFMA_16x16x32_F16) || \ - defined(USING_MFMA_32x32x16_F16) || \ - defined(USING_MFMA_16x16x128_F4) || \ - defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1 - static constexpr auto mfma_per_wg = 1; - static constexpr auto dsread_per_wg = 1; - #elif (defined(USING_MFMA_16x16x128_F8) || \ - defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2 - static constexpr auto mfma_per_wg = 1; - static constexpr auto dsread_per_wg = 2; - #endif + // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 + if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 16 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 8 && + std::is_same_v)) + { + return {2, 1}; + } + // K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2 + else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 && + std::is_same_v)) + { + return {1, 2}; + } + // K1 per Mfma = 1 cases: mfma_per_wg = 1, dsread_per_wg = 1 + else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 128 /*&& + std::is_same_v */) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 /*&& + std::is_same_v */)) + { + return {1, 1}; + } + // Default configuration + else + { + return {1, 1}; + } + } + + static constexpr auto mfma_config = GetMfmaConfig(); + static constexpr auto mfma_per_wg = mfma_config.mfma_per_wg; + static constexpr auto dsread_per_wg = mfma_config.dsread_per_wg; + + // #if (defined(USING_MFMA_16x16x32_F8) || \ + // defined(USING_MFMA_32x32x16_F8) || \ + // defined(USING_MFMA_16x16x16_F16) || \ + // defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5 + // static constexpr auto mfma_per_wg = 2; + // static constexpr auto dsread_per_wg = 1; + // #elif (defined(USING_MFMA_16x16x32_F16) || \ + // defined(USING_MFMA_32x32x16_F16) || \ + // defined(USING_MFMA_16x16x128_F4) || \ + // defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1 + // static constexpr auto mfma_per_wg = 1; + // static constexpr auto dsread_per_wg = 1; + // #elif (defined(USING_MFMA_16x16x128_F8) || \ + // defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2 + // static constexpr auto mfma_per_wg = 1; + // static constexpr auto dsread_per_wg = 2; + // #endif [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -242,252 +295,260 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // 0 M7N1: 62 - - - - // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - - - #if 0 // MI350 FP8 16X16 128*256*256 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; + if constexpr(warp_m == 16 && warp_n == 16) + { +#if defined(__gfx950__) // MI350 FP8 16X16 128*256*256 + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } +#else + if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 128) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + else if(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) + { + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - #endif - #if 0 // MI350 FP8 16X16 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_barrier(0); - #endif - #if 0 // MI300 FP8 16X16 128*128*128 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - - __builtin_amdgcn_sched_barrier(0); - #endif - #if 0 // MI300 FP8 16X16 128*256*128 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - - __builtin_amdgcn_sched_barrier(0); - #endif - #if 0 //MI300 FP8 16X16 16*64*256 - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_barrier(0); - #endif + __builtin_amdgcn_sched_barrier(0); + } + } +#endif } - CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() { #if 0 From 7e1bd4b83903585e8112f156277886e2754f59f6 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 23 Jul 2025 15:01:53 +0800 Subject: [PATCH 04/55] sync --- .../ck_tile/18_flatmm/run_flatmm_example.inc | 2 + .../ops/epilogue/cshuffle_epilogue.hpp | 121 +++++++++++ .../ops/flatmm/kernel/flatmm_kernel.hpp | 191 +++++++++++++++--- 3 files changed, 285 insertions(+), 29 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b5957a7c53..bd2c154368 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -97,6 +97,8 @@ template float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_shuffle_dev_buf, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index bf58544259..fa3af14040 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -339,6 +339,127 @@ struct CShuffleEpilogue tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(out_dram_window, c_out_tensor); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); + } + }); + } + template + CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* p_smem, + ScaleM scale_m, + ScaleN scale_n) + { + const index_t iMWarp = get_warp_id() / kNWave; + const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + const index_t iMLane = get_lane_id() / NPerXdl; + const index_t iNLane = get_lane_id() % NPerXdl; + + constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); + + auto lds_tile = make_static_distributed_tensor(LdsTileDistr); + + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto o_lds_block = make_tensor_view( + static_cast(p_smem), lds_block_desc); + + auto in_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + LdsTileDistr); + + auto out_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + constexpr index_t num_access = SFC::get_num_of_access(); + + static_assert(std::is_same_v, + "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, num_access, 1>{}([&](auto iAccess) { + block_sync_lds(); + constexpr auto idx_y_start = SFC::get_index(iAccess); + + constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; + + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); + + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + + store_tile(in_lds_window, c_warptile_in_tensor_casted); + block_sync_lds(); + + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + + auto m1 = iMLane; + float scale_B = scale_n[nIter * NPerIterationShuffle]; + static_for<0, kM0, 1>{}([&](auto m0) { + static_for<0, kM2, 1>{}([&](auto m2) { + float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl + + m0 * kM1 * kM2 + m1 * kM2 + m2]; + c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B; + }); + }); + + const auto ds_tensor = generate_tuple( + [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); + + const auto c_ds_tiles = concat_tuple_of_reference( + tie(c_out_tensor, c_out_tensor), + generate_tie( + [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 607645c097..653d4ed431 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -11,12 +11,97 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { - -template -struct FlatmmHostArgs +struct FlatmmProblem { - CK_TILE_HOST FlatmmHostArgs() = default; - CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_, + CK_TILE_HOST FlatmmProblem() = default; + CK_TILE_HOST FlatmmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +template +struct FlatmmScalePointer +{ + static constexpr int granularity = SharedGranularity; + + union + { + const float* ptr; + float scalar; // if shared granularity is 0, all rows/columns use the same scale value + }; + + CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; + CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {} + CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {} + + CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const + { + FlatmmScalePointer ret; + if constexpr(granularity == 0) + { + ret.scalar = scalar; + } + else if constexpr(granularity == 1) + { + ret.ptr = ptr + offset; + } + else + { + ret.ptr = ptr + offset / granularity; + } + return ret; + } + + CK_TILE_HOST_DEVICE float operator[](index_t i) const + { + if constexpr(granularity == 0) + { + return scalar; + } + else if constexpr(granularity == 1) + { + return ptr[i]; + } + else + { + return ptr[i / granularity]; + } + } +}; +// shared granularity = -1 means no scale +template <> +struct FlatmmScalePointer<-1> +{ + static constexpr int granularity = -1; + + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default; + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float scalar_) {} + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float* ptr_) {} + + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const + { + return FlatmmScalePointer{}; + } + CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const + { + return 1; // alway return 1, it doesn't change the result + } +}; + +template <> +struct BaseFlatmmHostArgs +{ + CK_TILE_HOST BaseFlatmmHostArgs() = default; + CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_, const void* b_ptr_, const std::array& ds_ptr_, void* e_ptr_, @@ -66,7 +151,37 @@ struct FlatmmHostArgs index_t k_batch; }; -template +template , class ScaleN = FlatmmScalePointer<-1>, index_t NumDTensor = 0> +struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<> +{ + CK_TILE_HOST ScaleFlatmmHostArgs() = default; + CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_, + const void* b_shuffle_ptr_, + const std::array& ds_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_C_, + ScaleM scale_m_ = nullptr, + ScaleN scale_n_ = nullptr) + : BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_, k_batch_), + scale_m(scale_m_), + scale_n(scale_n_) + { + } + ScaleM scale_m = nullptr; + ScaleN scale_n = nullptr; +}; + +template +using FlatmmHostArgs = ScaleFlatmmHostArgs, FlatmmScalePointer<-1>, NumberTensor>; + +template struct FlatmmKernelArgs { const void* a_ptr; @@ -82,6 +197,8 @@ struct FlatmmKernelArgs std::array stride_Ds; index_t stride_E; index_t k_batch; + ScaleM scale_m_ptr = nullptr; + ScaleN scale_n_ptr = nullptr; }; template @@ -113,7 +230,7 @@ struct FlatmmKernel static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); - using KernelArgs = FlatmmKernelArgs; + // using KernelArgs = FlatmmKernelArgs; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -129,21 +246,24 @@ struct FlatmmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - CK_TILE_HOST static constexpr KernelArgs - MakeKernelArgs(const FlatmmHostArgs& hostArgs) + template + CK_TILE_HOST static constexpr FlatmmKernelArgs + MakeKernelArgs(const FlatmmHostArgs& hostArgs) { - return KernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.ds_ptr, - hostArgs.e_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_Ds, - hostArgs.stride_E, - hostArgs.k_batch}; + return {hostArgs.a_ptr, + hostArgs.b_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch, + hostArgs.scale_m, + hostArgs.scale_n}; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize() @@ -157,8 +277,8 @@ struct FlatmmKernel struct SplitKBatchOffset { - __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) - { + template + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; @@ -196,6 +316,7 @@ struct FlatmmKernel index_t splitted_k; }; + template CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) { if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && @@ -341,7 +462,7 @@ struct FlatmmKernel return DTesnorIsValid; } - template + template CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, const BDataType* b_flat_ptr, @@ -559,14 +680,14 @@ struct FlatmmKernel return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window); } - template + template CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, const BDataType* b_flat_ptr, const std::array& ds_ptr, EDataType* e_ptr, void* smem_ptr_ping, void* smem_ptr_pong, - const KernelArgs& kargs, + const FlatmmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -588,8 +709,18 @@ struct FlatmmKernel a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); // Run Epilogue Pipeline - - if(UseDefaultScheduler || (get_warp_id() == 0)) + if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1) + { + auto& c_block_window = gemm_tile_windows.at(I3); + EpiloguePipeline{}.template operator()( + c_block_window, + c_block_tile, + d_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); @@ -598,7 +729,9 @@ struct FlatmmKernel } } - CK_TILE_DEVICE void operator()(KernelArgs kargs) const + template + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs, + int partition_idx = blockIdx.x) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); From 6dacf833da2fd12855541a68d0dd2c30efa42127 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 23 Jul 2025 07:20:26 +0000 Subject: [PATCH 05/55] fix bug --- .../flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1a2348810e..46f4deb01d 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -545,10 +545,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_barrier(0); } - } #endif + } } + CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() { #if 0 From 3f7d848dd35fea7cd3978073bda94cdefbf7d81d Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 23 Jul 2025 15:38:12 +0800 Subject: [PATCH 06/55] build pass --- .../ck_tile/18_flatmm/run_flatmm_example.inc | 2 - .../ops/epilogue/cshuffle_epilogue.hpp | 178 +++++++++--------- .../ops/flatmm/kernel/flatmm_kernel.hpp | 8 +- 3 files changed, 93 insertions(+), 95 deletions(-) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index bd2c154368..b5957a7c53 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -97,8 +97,6 @@ template float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_shuffle_dev_buf, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index fa3af14040..bf5539f702 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -368,118 +368,118 @@ struct CShuffleEpilogue ScaleM scale_m, ScaleN scale_n) { - const index_t iMWarp = get_warp_id() / kNWave; - const index_t iNWarp = get_warp_id() - iMWarp * kNWave; - const index_t iMLane = get_lane_id() / NPerXdl; - const index_t iNLane = get_lane_id() % NPerXdl; + // const index_t iMWarp = get_warp_id() / kNWave; + // const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + // const index_t iMLane = get_lane_id() / NPerXdl; + // const index_t iNLane = get_lane_id() % NPerXdl; - constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); + // constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - auto lds_tile = make_static_distributed_tensor(LdsTileDistr); + // auto lds_tile = make_static_distributed_tensor(LdsTileDistr); - constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); - auto o_lds_block = make_tensor_view( - static_cast(p_smem), lds_block_desc); + // constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + // auto o_lds_block = make_tensor_view( + // static_cast(p_smem), lds_block_desc); - auto in_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - LdsTileDistr); + // auto in_lds_window = make_tile_window( + // o_lds_block, + // make_tuple(number{}, number{}), + // {0, 0}, + // LdsTileDistr); - auto out_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}); + // auto out_lds_window = make_tile_window( + // o_lds_block, + // make_tuple(number{}, number{}), + // {0, 0}); - using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; - constexpr index_t num_access = SFC::get_num_of_access(); + // using SFC = space_filling_curve, + // sequence<0, 1>, + // sequence>; + // constexpr index_t num_access = SFC::get_num_of_access(); - static_assert(std::is_same_v, - "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + // static_assert(std::is_same_v, + // "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + // using TileEncodingPattern = + // TileDistributionEncodingPattern2D; + // constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); - auto d_dram_windows = generate_tuple( - [&](auto idx) { - return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - }, - number{}); + // auto d_dram_windows = generate_tuple( + // [&](auto idx) { + // return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + // }, + // number{}); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + // constexpr auto c_warp_y_lengths = + // to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - static_for<0, num_access, 1>{}([&](auto iAccess) { - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(iAccess); + // static_for<0, num_access, 1>{}([&](auto iAccess) { + // block_sync_lds(); + // constexpr auto idx_y_start = SFC::get_index(iAccess); - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; + // constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + // constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); + // lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + // merge_sequences( + // sequence{}, + // c_warp_y_index_zeros), + // merge_sequences(sequence{}, + // c_warp_y_lengths)); - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + // const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - store_tile(in_lds_window, c_warptile_in_tensor_casted); - block_sync_lds(); + // store_tile(in_lds_window, c_warptile_in_tensor_casted); + // block_sync_lds(); - auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + // auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - auto m1 = iMLane; - float scale_B = scale_n[nIter * NPerIterationShuffle]; - static_for<0, kM0, 1>{}([&](auto m0) { - static_for<0, kM2, 1>{}([&](auto m2) { - float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl + - m0 * kM1 * kM2 + m1 * kM2 + m2]; - c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B; - }); - }); + // auto m1 = iMLane; + // float scale_B = scale_n[nIter * NPerIterationShuffle]; + // static_for<0, kM0, 1>{}([&](auto m0) { + // static_for<0, kM2, 1>{}([&](auto m2) { + // float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl + + // m0 * kM1 * kM2 + m1 * kM2 + m2]; + // c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B; + // }); + // }); - const auto ds_tensor = generate_tuple( - [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); + // const auto ds_tensor = generate_tuple( + // [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); - const auto c_ds_tiles = concat_tuple_of_reference( - tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + // const auto c_ds_tiles = concat_tuple_of_reference( + // tie(c_out_tensor, c_out_tensor), + // generate_tie( + // [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); - tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + // tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - if constexpr(iAccess != num_access - 1) - { - constexpr auto step = SFC::get_forward_step(iAccess); + // if constexpr(MemoryOperation == memory_operation_enum::set) + // { + // store_tile(out_dram_window, c_out_tensor); + // } + // else + // { + // update_tile(out_dram_window, c_out_tensor); + // } + // if constexpr(iAccess != num_access - 1) + // { + // constexpr auto step = SFC::get_forward_step(iAccess); - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + // move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); - }); - } - }); + // static_for<0, NumDTensor, 1>{}([&](auto idx) { + // move_tile_window(d_dram_windows[idx], + // {step.at(number<0>{}), step.at(number<1>{})}); + // }); + // } + // }); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 653d4ed431..eee2eeb769 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -97,7 +97,7 @@ struct FlatmmScalePointer<-1> } }; -template <> +template struct BaseFlatmmHostArgs { CK_TILE_HOST BaseFlatmmHostArgs() = default; @@ -169,7 +169,7 @@ struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<> index_t stride_C_, ScaleM scale_m_ = nullptr, ScaleN scale_n_ = nullptr) - : BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_, k_batch_), + : BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_), scale_m(scale_m_), scale_n(scale_n_) { @@ -248,7 +248,7 @@ struct FlatmmKernel template CK_TILE_HOST static constexpr FlatmmKernelArgs - MakeKernelArgs(const FlatmmHostArgs& hostArgs) + MakeKernelArgs(const ScaleFlatmmHostArgs& hostArgs) { return {hostArgs.a_ptr, hostArgs.b_ptr, @@ -754,7 +754,7 @@ struct FlatmmKernel is_any_of::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); - RunFlatmm(a_ptr, + RunFlatmm(a_ptr, b_flat_ptr, kargs.ds_ptr, e_ptr, From 89fa639207b78c63505560d013692ada7cd8a6fd Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Wed, 23 Jul 2025 08:44:12 +0000 Subject: [PATCH 07/55] merge flatmm pipe v0 from dteng_flatmm_opt --- example/ck_tile/18_flatmm/CMakeLists.txt | 12 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 4 +- example/ck_tile/18_flatmm/flatmm_basic.hpp | 192 ++-- include/ck_tile/ops/flatmm.hpp | 1 + .../flatmm_pipeline_agmem_bgmem_creg_v0.hpp | 883 ++++++++++++++++++ 5 files changed, 987 insertions(+), 105 deletions(-) create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 87237458c5..30fd769c88 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,15 +1,13 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) + +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef -Wno-unused-variable -Wno-unused-parameter) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32_F8=1 -Wno-unused-local-typedef) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16_F8=1 -Wno-unused-local-typedef) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x128_F8=1 -Wno-unused-local-typedef) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") + +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1") target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c93d708910..6f977a803c 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -63,7 +63,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV0; const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; @@ -90,7 +90,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c tail_number_v>; using CodegenFlatmmPipeline = - ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + ck_tile::FlatmmPipelineAGmemBGmemCRegV0; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType); static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128; - static constexpr int kBlockPerCu = 1; + static constexpr int kBlockPerCu = 2; }; template @@ -167,119 +167,119 @@ struct is_8bit_type { }; -template -struct GemmConfig -{ -#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16 - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; +// template +// struct GemmConfig +// { +// #if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16 +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 256; +// static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; -#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 16; +// static constexpr ck_tile::index_t N_Warp_Tile = 16; +// static constexpr ck_tile::index_t K_Warp_Tile = 128; +// #elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 128; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 32; +// static constexpr ck_tile::index_t N_Warp_Tile = 32; +// static constexpr ck_tile::index_t K_Warp_Tile = 64; +// #elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 128; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 16; +// static constexpr ck_tile::index_t N_Warp_Tile = 16; +// static constexpr ck_tile::index_t K_Warp_Tile = 32; +// #elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 128; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; -#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16 - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; +// static constexpr ck_tile::index_t M_Warp_Tile = 32; +// static constexpr ck_tile::index_t N_Warp_Tile = 32; +// static constexpr ck_tile::index_t K_Warp_Tile = 16; +// #elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16 +// static constexpr ck_tile::index_t M_Tile = 16; +// static constexpr ck_tile::index_t N_Tile = 64; +// static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 16; +// static constexpr ck_tile::index_t N_Warp_Tile = 16; +// static constexpr ck_tile::index_t K_Warp_Tile = 64; +// #elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 256; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 8; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 8; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 32; +// static constexpr ck_tile::index_t N_Warp_Tile = 32; +// static constexpr ck_tile::index_t K_Warp_Tile = 32; +// #elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 128; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; +// static constexpr ck_tile::index_t M_Warp_Tile = 16; +// static constexpr ck_tile::index_t N_Warp_Tile = 16; +// static constexpr ck_tile::index_t K_Warp_Tile = 32; +// #elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune) +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 128; +// static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; -#else - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; +// static constexpr ck_tile::index_t M_Warp_Tile = 32; +// static constexpr ck_tile::index_t N_Warp_Tile = 32; +// static constexpr ck_tile::index_t K_Warp_Tile = 16; +// #else +// static constexpr ck_tile::index_t M_Tile = 128; +// static constexpr ck_tile::index_t N_Tile = 256; +// static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; +// static constexpr ck_tile::index_t M_Warp = 1; +// static constexpr ck_tile::index_t N_Warp = 4; +// static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; -#endif -}; +// static constexpr ck_tile::index_t M_Warp_Tile = 16; +// static constexpr ck_tile::index_t N_Warp_Tile = 16; +// static constexpr ck_tile::index_t K_Warp_Tile = 128; +// #endif +// }; auto create_args(int argc, char* argv[]) { diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 1714789e63..7a69eb3e98 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp new file mode 100644 index 0000000000..c4a16a121f --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp @@ -0,0 +1,883 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +template +struct BaseFlatmmPipelineAGmemBGmemCRegV0 +{ + static constexpr index_t PrefetchStages = 2; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) + { + if (TailNumber::Even == tail_num) + { + return run_func(bool_constant{}, integral_constant{}); + } + else if (TailNumber::Odd == tail_num) + { + return run_func(bool_constant{}, integral_constant{}); + } + // assert(false); + return run_func(bool_constant{}, integral_constant{}); + // return run_func(bool_constant{}, integral_constant{}); + } +}; + +template +struct FlatmmPipelineAGmemBGmemCRegV0 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t kLdsAlignmentInBytes = 16; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; + static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; + static constexpr index_t BloadGap = MIterPerWarp / 2; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + + static constexpr auto warp_m = WarpTile::at(idxM); + static constexpr auto warp_n = WarpTile::at(idxN); + static constexpr auto warp_k = WarpTile::at(idxK); + + /* + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1 + defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1 + defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1 + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 + */ + struct MfmaConfig + { + int mfma_per_wg; + int dsread_per_wg; + }; + static constexpr MfmaConfig GetMfmaConfig() + { + + // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 + if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 16 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 8 && + std::is_same_v)) + { + return {2, 1}; + } + // K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2 + else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 && + std::is_same_v)) + { + return {1, 2}; + } + // K1 per Mfma = 1 cases: mfma_per_wg = 1, dsread_per_wg = 1 + else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 128 /*&& + std::is_same_v */) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 /*&& + std::is_same_v */)) + { + return {1, 1}; + } + // Default configuration + else + { + return {1, 1}; + } + } + + static constexpr auto mfma_config = GetMfmaConfig(); + static constexpr auto mfma_per_wg = mfma_config.mfma_per_wg; + static constexpr auto dsread_per_wg = mfma_config.dsread_per_wg; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N0: 57 - 8 - - + // -1 M6N1: 58 1 - - - + // -1 M6N2: 59 - - 7 - + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 3 - - - + // -1 M7N2: 63 - - 8 - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 5 - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 7 - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 9 - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 11 - - 8 + // 0 M3N2: 15 - - - - + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 13 - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 15 - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 17 - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 18 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 19 - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 20 - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 21 - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 22 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 23 - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 24 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 25 - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 26 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 27 - - 16 + // 0 M3N2: 47 - - - - + // 0 M3N3: 48 28 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 29 - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 30 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 31 - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 32 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 1 - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 3 - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - + + #if 0 + constexpr auto dsread_num_perK = dsread_per_wg * MIterPerWarp; + constexpr auto dswrite_num_perK = (dsread_num_perK + MWarp * NWarp - 1) / (MWarp * NWarp); + constexpr auto dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; + + // index_t dsread_perM[MIterPerWarp]; + // index_t dswrite_perM[MIterPerWarp]; + index_t dsread_perM[MIterPerWarp]; + index_t dswrite_perM[MIterPerWarp]; + index_t load_perM[MIterPerWarp]; + + constexpr int dswrite_inst = dswrite_num_perK; + constexpr int NIter_num = NIterPerWarp*mfma_per_wg; + + #pragma unroll + for(int i=0;i 0 ? dswrite_inst - MIterPerWarp + 2 : 0; + } + else if(i==MIterPerWarp-1) + { + dswrite_perM[MIterPerWarp-1] = 0; + } + else + { + dswrite_perM[i] = (i + 2 - dswrite_inst) > 0 ? 1 : 0; + } + } + + #pragma unroll + for(int i=0;i<4;i++) + { + load_perM[i] = 2; + } + + #pragma unroll + for(int i=4;i<8;i++) + { + load_perM[i] = 1; + } + + #pragma unroll + for(int i=0;i load_perM[i] ? (dsread_perM[i] > dswrite_perM[i] ? dsread_perM[i] : dswrite_perM[i]) : (load_perM[i] > dswrite_perM[i] ? load_perM[i] : dswrite_perM[i]); + int total_num = dsread_perM[i] + load_perM[i] + dswrite_perM[i]; + int gap = (total_num+NIter_num-1)/NIter_num; + + index_t inst_order[MIterPerWarp*10]; + #pragma unroll + for(int j=0;jj) + { + inst_order[index] = 1; + index++; + } + if(load_perM[i]>j) + { + inst_order[index] = 2; + index++; + } + if(dsread_perM[i]>j) + { + inst_order[index] = 3; + index++; + } + } + + #pragma unroll + for(int j=0;j{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + + CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() + { + #if 0 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_barrier(0); + #endif + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + const index_t iMWarp = get_warp_id() / NWarp; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + // Acc register tile + auto c_block_tile = block_flatmm.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_pong; + + + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + + index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) + { + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter)); + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + block_sync_lds(); + + HotLoopScheduler(); + + //Next K + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter)); + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + block_sync_lds(); + + HotLoopScheduler(); + + iCounter--; + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter)); + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + block_sync_lds(); + TailHotLoopScheduler(); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter)); + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter)); + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + } +}; + +} // namespace ck_tile From 5a1183ebbdece8e69e7ebbf5b95e8cde40937d3a Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Wed, 23 Jul 2025 19:04:22 +0000 Subject: [PATCH 08/55] support flatmm scaling --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 44 +++- example/ck_tile/18_flatmm/flatmm_basic.hpp | 121 +--------- .../ck_tile/18_flatmm/run_flatmm_example.inc | 133 ++++++++--- include/ck_tile/core/container/sequence.hpp | 8 +- .../ck_tile/host/reference/reference_gemm.hpp | 144 ++++++++++++ .../ops/epilogue/cshuffle_epilogue.hpp | 210 ++++++++++-------- .../ops/flatmm/kernel/flatmm_kernel.hpp | 134 ++++++----- 7 files changed, 476 insertions(+), 318 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c93d708910..7c6559ba2a 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -23,9 +23,12 @@ template -float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s) +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) { using CodegenFlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -81,13 +84,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c constexpr auto memory_operation = memory_operation_.value; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; + BDataType, + AccDataType, + CodegenFlatmmShape, + CodegenGemmTraits, + scheduler, + has_hot_loop_v, + tail_number_v>; using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; @@ -217,6 +220,7 @@ int run_flatmm_example(int argc, char* argv[]) std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + int scale_opt = arg_parser.get_int("scale"); if(a_layout == "R" && b_layout == "C") { if(data_type == "fp16") @@ -231,13 +235,29 @@ int run_flatmm_example(int argc, char* argv[]) } else if(data_type == "fp8") { - run_flatmm_example_with_layouts>( - argc, argv, Row{}, Col{}, Row{}); + if(scale_opt == 0) + { + run_flatmm_example_with_layouts>( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + run_flatmm_example_with_layouts, 1, 1>( + argc, argv, Row{}, Col{}, Row{}); + } } else if(data_type == "bf8") { - run_flatmm_example_with_layouts>( - argc, argv, Row{}, Col{}, Row{}); + if(scale_opt == 0) + { + run_flatmm_example_with_layouts>( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + run_flatmm_example_with_layouts, 1, 1>( + argc, argv, Row{}, Col{}, Row{}); + } } else { diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 38858ecde8..2b94325f6c 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -83,10 +83,10 @@ struct FlatmmConfig16 template struct FlatmmConfig16_950 : public FlatmmConfig16 { - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType); + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType); static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128; - static constexpr int kBlockPerCu = 1; + static constexpr int kBlockPerCu = 1; }; template @@ -167,120 +167,6 @@ struct is_8bit_type { }; -template -struct GemmConfig -{ -#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16 - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; -#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; -#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16 - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; -#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 8; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 32; -#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune) - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; -#else - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 256; - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; -#endif -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -301,6 +187,7 @@ auto create_args(int argc, char* argv[]) .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") .insert("warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index b5957a7c53..e48cf43448 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -18,7 +18,7 @@ constexpr const char* DataTypeToString() { return "bf8"; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return "bf16"; } @@ -83,9 +83,12 @@ template -float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s); +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s); template float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::DeviceMem& b_shuffle_dev_buf, @@ -108,21 +113,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, + ScaleM scale_m, + ScaleN scale_n, int n_warmup, int n_repeat) { - ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(), - b_shuffle_dev_buf.GetDeviceBuffer(), - {}, - c_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_m, + scale_n}; float ave_time = flatmm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); @@ -154,6 +165,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, template @@ -197,21 +210,30 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::HostTensor c_rslt_host( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::HostTensor per_token_scale(ck_tile::HostTensorDescriptor({M}, {1})); + ck_tile::HostTensor per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1})); + // TODO: add different init types if(init_method == 0) { ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(per_token_scale); + ck_tile::FillUniformDistribution{-1.f, 1.f}(per_channel_scale); } else if(init_method == 1) { ck_tile::FillMonotonicSeq{}(a_host); ck_tile::FillMonotonicSeq{}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(per_token_scale); + ck_tile::FillUniformDistribution{1.f, 1.f}(per_channel_scale); } else if(init_method == 2) { ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(per_token_scale); + ck_tile::FillUniformDistribution{1.f, 1.f}(per_channel_scale); } else { @@ -222,14 +244,25 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes()); + ck_tile::DeviceMem per_channel_scale_dev_buf( + per_channel_scale.get_element_space_size_in_bytes()); + a_dev_buf.ToDevice(a_host.data()); c_rslt_host.SetZero(); + per_token_scale_dev_buf.ToDevice(per_token_scale.data()); + per_channel_scale_dev_buf.ToDevice(per_channel_scale.data()); // do pre-shuffle ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host); ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(per_token_scale_dev_buf.GetDeviceBuffer())}; + auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())}; + invoke_flatmm, - CLayout>(a_dev_buf, - b_shuffle_dev_buf, - c_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + CLayout, + decltype(per_token_scale_dev_ptr), + decltype(per_channel_scale_dev_ptr)>(a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + per_token_scale_dev_ptr, + per_channel_scale_dev_ptr, + n_warmup, + n_repeat); c_dev_buf.FromDevice(c_rslt_host.data()); bool pass = true; @@ -263,6 +300,8 @@ int run_flatmm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { + assert(ScaleGranularityM == -1 && ScaleGranularityN == -1 && + "ScaleAB is not supported for CPU verification!"); ck_tile::HostTensor c_ref_host( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_ref_host.SetZero(); @@ -310,13 +349,41 @@ int run_flatmm_example_with_layouts(int argc, N * K * sizeof(BDataType), hipMemcpyHostToDevice)); - ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1) + { + ck_tile::reference_gemm_gpu( + d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + } + else + { + ck_tile::reference_blockwise_gemm_gpu( + d_A, + d_B, + d_C, + M, + N, + K, + stride_A, + stride_B, + stride_C, + ScaleGranularityM, + ScaleGranularityN, + K, + static_cast(per_token_scale_dev_buf.GetDeviceBuffer()), + static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())); + } ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(), d_C, diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index b187b71830..6114716111 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -165,6 +165,9 @@ struct sequence return sequence{}; } + CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); } + CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); } + // pickup element at index template CK_TILE_HOST_DEVICE static constexpr auto extract(number...) @@ -1236,9 +1239,8 @@ constexpr auto reverse_slice_sequence(Seq, template ::type> -constexpr auto slice_sequence(Seq, - number, - Mask = typename uniform_sequence_gen::type{}) +constexpr auto +slice_sequence(Seq, number, Mask = typename uniform_sequence_gen::type{}) { constexpr auto r = reverse_slice_sequence(Seq{}.reverse(), number{}, Mask{}.reverse()); diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index c88deaec01..351c04543a 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -195,6 +195,104 @@ __global__ void naive_gemm_kernel(ADataType* A, } } +template +__global__ void blockwise_gemm_kernel(ADataType* A, + BDataType* B, + CDataType* C, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t strideA, + ck_tile::index_t strideB, + ck_tile::index_t strideC, + ck_tile::index_t scale_granularity_m, + ck_tile::index_t scale_granularity_n, + ck_tile::index_t scale_granularity_k, + float* scale_A_ptr, + float* scale_B_ptr) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int row = idx / N; // Compute row index + int col = idx % N; // Compute column index + + if(row < M && col < N) + { + AccDataType acc = 0.0, acc_temp = 0.0; + + index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m; + index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n; + + float scale_A = 0; + float scale_B = 0; + + for(int k = 0; k < K; ++k) + { + if(k % scale_granularity_k == 0) + { + // update acc + acc += acc_temp * scale_A * scale_B; + acc_temp = 0.0; + // update scale factors + scale_A = scale_A_ptr[(row / scale_granularity_m) + + (k / scale_granularity_k) * scale_A_stride]; + scale_B = scale_B_ptr[(col / scale_granularity_n) + + (k / scale_granularity_k) * scale_B_stride]; + } + + constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; + constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; + // Adjust indexing based on matrix layout + int a_index = (std::is_same_v) + ? row * strideA + k + : k * strideA + row; + int b_index = (std::is_same_v) + ? col * strideB + k + : k * strideB + col; + + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(A[a_index]); + } + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(B[b_index]); + } + acc_temp += v_a * v_b; + } + // final accumulation + acc += acc_temp * scale_A * scale_B; + + int c_index = (std::is_same_v) + ? row * strideC + col + : col * strideC + row; + C[c_index] = ck_tile::type_convert(acc); + } +} + template +void reference_blockwise_gemm_gpu(ADataType* a_ptr, + BDataType* b_ptr, + CDataType* c_ptr, + index_t M, + index_t N, + index_t K, + index_t stride_a, + index_t stride_b, + index_t stride_c, + index_t scale_granularity_m, + index_t scale_granularity_n, + index_t scale_granularity_k, + float* scale_A_ptr, + float* scale_B_ptr) +{ + int totalElements = M * N; + int numThreadsPerBlock = 256; // Common choice for threads per block + int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; + + blockwise_gemm_kernel + <<>>(a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_a, + stride_b, + stride_c, + scale_granularity_m, + scale_granularity_n, + scale_granularity_k, + scale_A_ptr, + scale_B_ptr); + + return; +} + template , - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); static_assert(std::is_same_v, @@ -334,8 +334,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); @@ -360,7 +360,12 @@ struct CShuffleEpilogue } }); } - template + + template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, @@ -368,118 +373,133 @@ struct CShuffleEpilogue ScaleM scale_m, ScaleN scale_n) { - // const index_t iMWarp = get_warp_id() / kNWave; - // const index_t iNWarp = get_warp_id() - iMWarp * kNWave; - // const index_t iMLane = get_lane_id() / NPerXdl; - // const index_t iNLane = get_lane_id() % NPerXdl; + constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - // constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); + auto lds_tile = make_static_distributed_tensor(LdsTileDistr); - // auto lds_tile = make_static_distributed_tensor(LdsTileDistr); + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto o_lds_block = make_tensor_view( + static_cast(p_smem), lds_block_desc); - // constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); - // auto o_lds_block = make_tensor_view( - // static_cast(p_smem), lds_block_desc); + auto in_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + LdsTileDistr); - // auto in_lds_window = make_tile_window( - // o_lds_block, - // make_tuple(number{}, number{}), - // {0, 0}, - // LdsTileDistr); + auto out_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); - // auto out_lds_window = make_tile_window( - // o_lds_block, - // make_tuple(number{}, number{}), - // {0, 0}); + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + constexpr index_t num_access = SFC::get_num_of_access(); - // using SFC = space_filling_curve, - // sequence<0, 1>, - // sequence>; - // constexpr index_t num_access = SFC::get_num_of_access(); + static_assert(std::is_same_v, + "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); - // static_assert(std::is_same_v, - // "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); - // using TileEncodingPattern = - // TileDistributionEncodingPattern2D; - // constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); - // auto d_dram_windows = generate_tuple( - // [&](auto idx) { - // return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); - // }, - // number{}); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // constexpr auto c_warp_y_lengths = - // to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - // constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr int kM2 = 4; // Val + constexpr int kM1 = (64 / NPerXdl); // Thr + constexpr int kM0 = MPerXdl / kM1; // Val - // static_for<0, num_access, 1>{}([&](auto iAccess) { - // block_sync_lds(); - // constexpr auto idx_y_start = SFC::get_index(iAccess); + const index_t iMWarp = get_warp_id() / NWave; + const index_t iNWarp = get_warp_id() - iMWarp * NWave; + const index_t iMLane = get_lane_id() / NPerXdl; + const index_t iNLane = get_lane_id() % NPerXdl; - // constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - // constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; + static_for<0, num_access, 1>{}([&](auto iAccess) { + block_sync_lds(); + constexpr auto idx_y_start = SFC::get_index(iAccess); - // lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - // merge_sequences( - // sequence{}, - // c_warp_y_index_zeros), - // merge_sequences(sequence{}, - // c_warp_y_lengths)); + constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - // const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); - // store_tile(in_lds_window, c_warptile_in_tensor_casted); - // block_sync_lds(); + static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) { + float scale_B = + scale_n[nIter * NPerIterationShuffle + + iNWarp * NumNXdlPerWavePerShuffle * NPerXdl + n_xdl * NPerXdl + iNLane]; + static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { + constexpr int acc_xdl_offset = + (m_xdl * NumMXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product(); - // auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - - // auto m1 = iMLane; - // float scale_B = scale_n[nIter * NPerIterationShuffle]; - // static_for<0, kM0, 1>{}([&](auto m0) { - // static_for<0, kM2, 1>{}([&](auto m2) { - // float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl + - // m0 * kM1 * kM2 + m1 * kM2 + m2]; - // c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B; - // }); - // }); + static_for<0, kM0, 1>{}([&](auto m0) { + static_for<0, kM2, 1>{}([&](auto m2) { + float scale_A = + scale_m[mIter * MPerIterationShuffle + + iMWarp * NumMXdlPerWavePerShuffle * MPerXdl + + m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + m2]; + lds_tile.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *= + scale_A * scale_B; + }); + }); + }); + }); - // const auto ds_tensor = generate_tuple( - // [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - // const auto c_ds_tiles = concat_tuple_of_reference( - // tie(c_out_tensor, c_out_tensor), - // generate_tie( - // [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + store_tile(in_lds_window, c_warptile_in_tensor_casted); + block_sync_lds(); - // tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - // if constexpr(MemoryOperation == memory_operation_enum::set) - // { - // store_tile(out_dram_window, c_out_tensor); - // } - // else - // { - // update_tile(out_dram_window, c_out_tensor); - // } - // if constexpr(iAccess != num_access - 1) - // { - // constexpr auto step = SFC::get_forward_step(iAccess); + const auto ds_tensor = generate_tuple( + [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); - // move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + const auto c_ds_tiles = concat_tuple_of_reference( + tie(c_out_tensor, c_out_tensor), + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); - // static_for<0, NumDTensor, 1>{}([&](auto idx) { - // move_tile_window(d_dram_windows[idx], - // {step.at(number<0>{}), step.at(number<1>{})}); - // }); - // } - // }); + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(out_dram_window, c_out_tensor); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); + } + }); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index eee2eeb769..ee79c4c4eb 100755 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -102,17 +102,17 @@ struct BaseFlatmmHostArgs { CK_TILE_HOST BaseFlatmmHostArgs() = default; CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_, - const void* b_ptr_, - const std::array& ds_ptr_, - void* e_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t stride_A_, - index_t stride_B_, - const std::array& stride_Ds_, - index_t stride_E_) + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), ds_ptr(ds_ptr_), @@ -151,35 +151,49 @@ struct BaseFlatmmHostArgs index_t k_batch; }; -template , class ScaleN = FlatmmScalePointer<-1>, index_t NumDTensor = 0> +template , + class ScaleN = FlatmmScalePointer<-1>, + index_t NumDTensor = 0> struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<> { CK_TILE_HOST ScaleFlatmmHostArgs() = default; CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_, - const void* b_shuffle_ptr_, - const std::array& ds_ptr_, - void* c_ptr_, - index_t k_batch_, - index_t M_, - index_t N_, - index_t K_, - index_t stride_A_, - index_t stride_B_, - const std::array& stride_Ds_, - index_t stride_C_, - ScaleM scale_m_ = nullptr, - ScaleN scale_n_ = nullptr) - : BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_), - scale_m(scale_m_), - scale_n(scale_n_) + const void* b_shuffle_ptr_, + const std::array& ds_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_C_, + ScaleM scale_m_ = nullptr, + ScaleN scale_n_ = nullptr) + : BaseFlatmmHostArgs(a_ptr_, + b_shuffle_ptr_, + ds_ptr_, + c_ptr_, + k_batch_, + M_, + N_, + K_, + stride_A_, + stride_B_, + stride_Ds_, + stride_C_), + scale_m(scale_m_), + scale_n(scale_n_) { } ScaleM scale_m = nullptr; ScaleN scale_n = nullptr; }; -template -using FlatmmHostArgs = ScaleFlatmmHostArgs, FlatmmScalePointer<-1>, NumberTensor>; +template +using FlatmmHostArgs = + ScaleFlatmmHostArgs, FlatmmScalePointer<-1>, NumberTensor>; template struct FlatmmKernelArgs @@ -278,7 +292,8 @@ struct FlatmmKernel struct SplitKBatchOffset { template - __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) + { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; @@ -681,16 +696,17 @@ struct FlatmmKernel } template - CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr_ping, - void* smem_ptr_pong, - const FlatmmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + CK_TILE_DEVICE static void + RunFlatmm(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr_ping, + void* smem_ptr_pong, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = @@ -712,19 +728,21 @@ struct FlatmmKernel if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1) { auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template operator()( - c_block_window, - c_block_tile, - d_block_window, - smem_ptr_ping, - kargs.scale_m_ptr + block_idx_m, - kargs.scale_n_ptr + block_idx_n); + EpiloguePipeline{}.template + operator()( + c_block_window, + c_block_tile, + d_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); } else if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template operator()( + EpiloguePipeline{}.template + operator()( c_block_window, c_block_tile, d_block_window, smem_ptr_ping); } } @@ -755,15 +773,15 @@ struct FlatmmKernel { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); RunFlatmm(a_ptr, - b_flat_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_ping, - smem_ptr_pong, - kargs, - splitk_batch_offset, - i_m, - i_n); + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_ping, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); } } }; From b908f5e803e5141534bf081a76bf0b160eb36568 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Wed, 23 Jul 2025 19:12:31 +0000 Subject: [PATCH 09/55] fix flatmm syntax error on gfx950 --- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1062 +++++++++-------- 1 file changed, 572 insertions(+), 490 deletions(-) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1a2348810e..b8e923a52e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -13,7 +13,7 @@ namespace ck_tile { template struct BaseFlatmmPipelineAGmemBGmemCRegV1 { - static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefetchStages = 2; CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -25,19 +25,23 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) { - if (TailNumber::Even == tail_num) + if(TailNumber::Even == tail_num) { - return run_func(bool_constant{}, integral_constant{}); + return run_func(bool_constant{}, + integral_constant{}); } - else if (TailNumber::Odd == tail_num) + else if(TailNumber::Odd == tail_num) { - return run_func(bool_constant{}, integral_constant{}); + return run_func(bool_constant{}, + integral_constant{}); } // assert(false); return run_func(bool_constant{}, integral_constant{}); - // return run_func(bool_constant{}, integral_constant{}); + // return run_func(bool_constant{}, integral_constant{}); } }; @@ -56,8 +60,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV using BlockFlatmm = remove_cvref_t())>; - - static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto config = + BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -109,11 +114,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = 16 / sizeof(ADataType); - static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; - static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; - static constexpr index_t BloadGap = MIterPerWarp / 2; + static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; + static constexpr index_t BloadGap = MIterPerWarp / 2; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; @@ -145,21 +150,21 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 16 && - std::is_same_v) || - (warp_m == 16 && warp_n == 16 && warp_k == 16 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 8 && - std::is_same_v)) + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 16 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 8 && + std::is_same_v)) { return {2, 1}; } // K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2 else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 64 && - std::is_same_v)) + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 && + std::is_same_v)) { return {1, 2}; } @@ -227,73 +232,73 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // instruction schedule example(128X256X256, 1X4, 16X16X128): // Iter MNK MFMA ds_read ds_write A_load b_load // -1 M6N3: 60 2 - - - - // -1 M7N0: 61 - - - - - // -1 M7N1: 62 - - - - - // -1 M7N2: 63 - - - - - // -1 M7N3: 64 4 - - - - // 0 M0N0K0: 1 - - - - - // 0 M0N1: 2 - - - 2 - // 0 M0N2: 3 - - - - - // 0 M0N3: 4 6 - - - - // 0 M1N0: 5 - - - - - // 0 M1N1: 6 - - - 4 - // 0 M1N2: 7 - - - - - // 0 M1N3: 8 8 - - - - // 0 M2N0: 9 - - - - - // 0 M2N1: 10 - - - 6 - // 0 M2N2: 11 - - - - - // 0 M2N3: 12 10 - - - - // 0 M3N0: 13 - 1 - - - // 0 M3N1: 14 - - - 8 - // 0 M3N2: 15 - - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - // 0 M3N3: 16 12 - - - - // 0 M4N0: 17 - 2 - - - // 0 M4N1: 18 - - - - - // 0 M4N2: 19 - - 1 - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - // 0 M4N3: 20 14 - - - - // 0 M5N0: 21 - 3 - - - // 0 M5N1: 22 - - - - - // 0 M5N2: 23 - - 2 - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 - - 2 - // 0 M5N3: 24 16 - - - - // 0 M6N0: 25 - 4 - - - // 0 M6N1: 26 - - - - - // 0 M6N2: 27 - - 3 - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - // 0 M6N3: 28 17 - - - - // 0 M7N0: 29 - - - - - // 0 M7N1: 30 - - - - - // 0 M7N2: 31 - - 4 - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - // 0 M7N3: 32 18 - - - - // 0 M0N0K1: 33 - - - - - // 0 M0N1: 34 - - - 10 - // 0 M0N2: 35 - - - - - // 0 M0N3: 36 20 - - - - // 0 M1N0: 37 - - - - - // 0 M1N1: 38 - - - 12 - // 0 M1N2: 39 - - - - - // 0 M1N3: 40 22 - - - - // 0 M2N0: 41 - - - - - // 0 M2N1: 42 - - - 14 - // 0 M2N2: 43 - - - - - // 0 M2N3: 44 24 - - - - // 0 M3N0: 45 - 5 - - - // 0 M3N1: 46 - - - 16 - // 0 M3N2: 47 - - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - // 0 M3N3: 48 26 - - - - // 0 M4N0: 49 - 6 - - - // 0 M4N1: 50 - - - - - // 0 M4N2: 51 - - 5 - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - // 0 M4N3: 52 28 - - - - // 0 M5N0: 53 - 7 - - - // 0 M5N1: 54 - - - - - // 0 M5N2: 55 - - 6 - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - // 0 M5N3: 56 30 - - - - // 0 M6N0: 57 - 8 - - - // 0 M6N1: 58 - - - - - // 0 M6N2: 59 - - 7 - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - // 0 M6N3: 60 2 - - - - // 0 M7N0: 61 - - - - - // 0 M7N1: 62 - - - - - // 0 M7N2: 63 - - 8 - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - if constexpr(warp_m == 16 && warp_n == 16) { @@ -473,7 +478,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - + static_for<0, 1, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -531,7 +536,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -545,13 +550,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_barrier(0); } - } #endif + } } CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() { - #if 0 +#if 0 static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -593,7 +598,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); __builtin_amdgcn_sched_barrier(0); - #endif +#endif } template @@ -630,7 +635,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; __builtin_amdgcn_sched_barrier(0); - + // A tile in LDS ADataType* p_a_lds_ping = static_cast(p_smem_ping); ADataType* p_a_lds_pong = static_cast(p_smem_pong); @@ -638,11 +643,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); - auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); - // A DRAM tile window for load - #ifndef FINEGRADE_LOADSTORE +// A DRAM tile window for load +#ifndef FINEGRADE_LOADSTORE auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -657,10 +664,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - #else + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); +#else auto a_copy_dram_window_tmp = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -673,49 +680,49 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0}); }); - auto a_copy_lds_window_ping_tmp = make_tile_window( - a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution() - ); + auto a_copy_lds_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution()); - statically_indexed_array a_copy_lds_window_ping; + statically_indexed_array + a_copy_lds_window_ping; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp; move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0}); }); - auto a_copy_lds_window_pong_tmp = make_tile_window( - a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution() - ); + auto a_copy_lds_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution()); - statically_indexed_array a_copy_lds_window_pong; + statically_indexed_array + a_copy_lds_window_pong; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp; move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0}); }); - #endif +#endif // A LDS tile for block GEMM // auto a_lds_gemm_window = make_tile_window( // a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping_tmp = make_tile_window( - a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - auto a_warp_window_pong_tmp = make_tile_window( - a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -726,7 +733,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV statically_indexed_array, MIterPerWarp> a_warp_windows_pong; - + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; @@ -776,19 +783,19 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV NIterPerWarp> b_warp_tensor_pong; - - // Prefetch A0 - #ifndef FINEGRADE_LOADSTORE +// Prefetch A0 +#ifndef FINEGRADE_LOADSTORE auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #else - statically_indexed_array{}))), ACopyLoadNum> a_block_tile; +#else + statically_indexed_array{}))), ACopyLoadNum> + a_block_tile; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); }); - #endif +#endif // prefetch B static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -796,7 +803,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); @@ -815,29 +822,31 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // } // else // { - // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); + // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, + // a_block_tile)); // } - #ifndef FINEGRADE_LOADSTORE +#ifndef FINEGRADE_LOADSTORE auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - #else +#else static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - store_tile(a_copy_lds_window_ping(AIter), tile_elementwise_in(a_element_func, a_block_tile(AIter))); + store_tile(a_copy_lds_window_ping(AIter), + tile_elementwise_in(a_element_func, a_block_tile(AIter))); }); - #endif +#endif __builtin_amdgcn_sched_barrier(0); - // Prefetch A1 - #ifndef FINEGRADE_LOADSTORE +// Prefetch A1 +#ifndef FINEGRADE_LOADSTORE a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #else +#else static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); }); - #endif +#endif // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -845,420 +854,493 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV block_sync_lds(); // preload A00,A10 from lds - constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1; - statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_ping; - statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_pong; - + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_pong; + static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); // if(threadIdx.x==0){ // for(int i=0;i(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); + // printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n", + // threadIdx.x, + // type_convert(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); // } // } // for(int i=0;i{}).get_thread_buffer_size();i++) { - // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", threadIdx.x, type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); + // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", + // threadIdx.x, + // type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); // } - index_t iCounter = (num_loop - 1) / 2; // if constexpr(HasMainLoop) // { - while(iCounter > 0) - { - #ifndef FINEGRADE_LOADSTORE - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + while(iCounter > 0) + { +#ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); +#endif + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #endif - - // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+1) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) +#ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) { - block_sync_lds(); + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && + (mIter < (MIterPerWarp - 1 + 1)) && + ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + a_block_tile(number{}) = + load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + HotLoopScheduler(); + + // Next K + +#ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); - HotLoopScheduler(); - - //Next K + }); - #ifndef FINEGRADE_LOADSTORE - // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); +#endif + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(2i+2) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - // Prefetch A(2i+3) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #endif - - // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+2) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_ping(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) +#ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_ping(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) { - block_sync_lds(); + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_ping(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && + (mIter < (MIterPerWarp - 1 + 1)) && + ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + a_block_tile(number{}) = + load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); - }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - HotLoopScheduler(); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - iCounter--; - } + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - // tail - if constexpr(TailNum == TailNumber::Even) - { - // __builtin_amdgcn_sched_barrier(0); - #ifndef FINEGRADE_LOADSTORE - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + HotLoopScheduler(); + + iCounter--; + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { +// __builtin_amdgcn_sched_barrier(0); +#ifndef FINEGRADE_LOADSTORE + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); +#endif + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + +#ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - #endif + TailHotLoopScheduler(); - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); - } + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } }); - //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - TailHotLoopScheduler(); + // TailHotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); - }); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // __builtin_amdgcn_sched_barrier(0); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + +#ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); - }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } - // TailHotLoopScheduler(); - // __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(TailNum == TailNumber::Odd) - { - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); - } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - } + }); + } // } return c_block_tile; @@ -1273,7 +1355,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem_ping, From 68390988c99d094328e02bdf34517e5bd13060db Mon Sep 17 00:00:00 2001 From: solin Date: Thu, 24 Jul 2025 04:38:16 +0000 Subject: [PATCH 10/55] reorg flatmm code --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 353 +++++++++++++++++- example/ck_tile/18_flatmm/flatmm_basic.hpp | 38 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 348 ----------------- 3 files changed, 359 insertions(+), 380 deletions(-) delete mode 100644 example/ck_tile/18_flatmm/run_flatmm_example.inc diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 6f977a803c..dfe3d6c3cb 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -11,7 +11,142 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" -#include "run_flatmm_example.inc" +#include + +template +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else + { + return "unknown"; + } +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / FlatmmConfig::K_Warp_Tile, + divisor, + FlatmmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} template & args, const ck_tile::stream_c return ave_time; } +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +int run_flatmm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + // persistent not added + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_host( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_origin_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_rslt_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + // TODO: add different init types + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_host); + ck_tile::FillMonotonicSeq{}(b_origin_host); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + } + else + { + a_host.SetZero(); + b_origin_host.SetZero(); + } + + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + c_rslt_host.SetZero(); + + // do pre-shuffle + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + + invoke_flatmm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_dev_buf.FromDevice(c_rslt_host.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_ref_host.SetZero(); + + ck_tile::reference_gemm( + a_host, b_origin_host, c_ref_host); + const float max_accumulated_value = + *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes()); + b_origin_dev_buf.ToDevice(b_origin_host.data()); + + ck_tile::HostTensor c_gpu_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes()); + c_gpu_ref_host.SetZero(); + c_gpu_ref_dev_buf.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy( + d_A, a_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_origin_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); + const float max_accumulated_value = + *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_gpu_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + template