diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index ed2006d4b9..97d54d826d 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -238,9 +238,9 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase template struct GemmConfigComputeV5 : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 8; //64 / sizeof(PrecType); static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 1; @@ -248,11 +248,13 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = 8; //get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + + static constexpr ck_tile::index_t PingPongDim = 1; // 0 - Off, 1 - M, 2 - N and 3 - K }; template @@ -486,7 +488,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index b80d9991d4..b68c852b25 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -60,7 +60,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) GemmConfig::UseStructuredSparsity, Persistent, GemmConfig::NumWaveGroups, - GemmConfig::Preshuffle>; + GemmConfig::Preshuffle, + GemmConfig::PingPongDim>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -94,6 +95,24 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; + /* + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem< + ADataType, + BDataType, + AccDataType, + CDataType, + CLayout, + GemmConfig::kPadM, + GemmConfig::kPadN, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + false, + memory_operation>>; + */ + using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } } else { @@ -210,61 +250,18 @@ int run_gemm_example_prec_type(std::string a_layout, using Col = ck_tile::tensor_layout::gemm::ColumnMajor; bool preshuffle = GemmConfig::Preshuffle; - if(preshuffle && std::is_same_v) - { - throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); - } + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - if(preshuffle && a_layout != "R" && b_layout != "C") - { - throw std::runtime_error( - "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); - } - if constexpr(std::is_same_v) + if(a_layout == "R" && b_layout == "C") { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices when " - "BPrecType is ck_tile::pk_int4_t!"); - } + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); } else { - if(a_layout == "R" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Row{}, Row{}); - } - else if(a_layout == "R" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Row{}, Col{}, Row{}); - } - else if(a_layout == "C" && b_layout == "R") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Row{}, Row{}); - } - else if(a_layout == "C" && b_layout == "C") - { - return run_gemm_example_with_layouts( - arg_parser, Col{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } + throw std::runtime_error("Unsupported memory layout for the input matrices!"); } } @@ -275,52 +272,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - if(data_type == "fp16") - { - return run_gemm_example_prec_type, ck_tile::half_t>( - a_layout, b_layout, arg_parser); - } - else if(data_type == "bf16") + if(data_type == "bf16") { return run_gemm_example_prec_type, ck_tile::bf16_t>( a_layout, b_layout, arg_parser); } - else if(data_type == "fp8") - { - return run_gemm_example_prec_type, - ck_tile::fp8_t, - ck_tile::fp8_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "bf8") - { - return run_gemm_example_prec_type, - ck_tile::bf8_t, - ck_tile::bf8_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "int8") - { - return run_gemm_example_prec_type, - ck_tile::int8_t, - ck_tile::int8_t, - ck_tile::int32_t>(a_layout, b_layout, arg_parser); - } - else if(data_type == "pk_int4_t") - { - // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) - { - return run_gemm_example_prec_type, - ck_tile::half_t, - ck_tile::pk_int4_t, - ck_tile::half_t>(a_layout, b_layout, arg_parser); - } - else - { - throw std::runtime_error("Unsupported pipeline for this operation !!!"); - } - } else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e37b4f36d4..6b2a5129f0 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -130,6 +130,11 @@ struct GemmKernel return UniversalGemmKernel::MaxOccupancyGridSize(s); } + CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3 + { + return dim3(TilePartitioner::GridSize(N, K), 1, KBatch); + } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return UniversalGemmKernel::BlockSize(); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index b621468e92..5b105b9a4d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -264,6 +264,11 @@ struct GemmSpatiallyLocalTilePartitioner return integer_divide_ceil(K, KPerBlock); } + CK_TILE_HOST_DEVICE auto GetPingPongMLoops(index_t NumWavefronts) noexcept -> index_t + { + return integer_divide_ceil(M, MPerBlock * NumWavefronts); + } + /** * @brief Calculate workgroup 1D index mapping into 2D output C-tile space. * diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8117d65758..f1c7142091 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -284,6 +284,11 @@ struct UniversalGemmKernel return dim3(grid_size, 1, 1); } + CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3 + { + return dim3(TilePartitioner::GridSize(N, K), 1, KBatch); + } + CK_TILE_HOST static auto BlockSize() { if(ck_tile::is_wave32()) @@ -845,6 +850,90 @@ struct UniversalGemmKernel } } + template + CK_TILE_DEVICE static auto MakePingPongGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = generate_tuple( + [&](auto i) { + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }, + number{}); + + const auto& b_pad_view = generate_tuple( + [&](auto i) { + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, number{}), + sequence{}); + } + }, + number>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view( + e_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + e_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_pad_view, d_pad_view, e_pad_view); + } + template CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) @@ -934,6 +1023,79 @@ struct UniversalGemmKernel return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); } + template + CK_TILE_DEVICE static auto MakePingPongGemmTileWindows + (const PadView& views, const index_t i_n, const index_t i_k, [[maybe_unused]] const index_t M, [[maybe_unused]] const index_t N, [[maybe_unused]] const index_t K) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window( + as_pad_view[i], make_tuple(number{}, number{}), {0, i_k}); + } + else + { + return make_tile_window( + as_pad_view[i], make_tuple(number{}, number{}), {i_k, 0}); + } + }, + number{}); + + const auto& bs_block_window = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window( + bs_pad_view[i], + make_tuple(number{}, number{}), + {i_n, i_k}); + } + else + { + return make_tile_window( + b_pad_view[i], + make_tuple(number{}, number{}), + {i_k, i_n}); + } + }, + number{}); + + const auto& ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v{}, number{}), + {i_m, i_n}); + } + else + { + return make_tile_window( + ds_pad_view[i], + make_tuple(number{}, number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {0, i_n}); + + return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + } + /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * @@ -987,6 +1149,43 @@ struct UniversalGemmKernel } } + CK_TILE_DEVICE static void PingPongGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const GemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_n, + const index_t block_idx_k) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = + MakePingPongGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + MakePingPongGemmTileWindows(gemm_pad_views, block_idx_n, block_idx_k, kargs.M, kargs.N, kargs.K); + + const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil( + //kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps)); + kargs.M, TilePartitioner::MPerBlock)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + auto& e_block_window = gemm_tile_windows.at(I3); + + const auto EpilogueFunc = [&](auto &window, auto& tile) { + EpiloguePipeline{}.template operator()( + window, tile, smem_ptr_0); + }; + + GemmPipeline{}.template operator()( + a_block_window, b_block_window, e_block_window, num_loop, smem_ptr_0, EpilogueFunc); + } + /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * @@ -1045,9 +1244,9 @@ struct UniversalGemmKernel CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const auto [iN, iK] = TilePartitioner{kargs.N, kargs.K}.GetOutputTileIndex(blockId); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_k = __builtin_amdgcn_readfirstlane(iK * TilePartitioner::KPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); @@ -1075,43 +1274,8 @@ struct UniversalGemmKernel // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } + PingPongGemm( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n, i_k); } // Persistent kernel entry point diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index b05145890f..15c290f57f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -49,6 +49,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using CLayout = remove_cvref_t; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t PingPongDim = Problem::PingPongDim; using BlockGemm = remove_cvref_t())>; using I0 = number<0>; @@ -85,6 +86,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr auto Scheduler = Problem::Scheduler; static constexpr index_t NumWarps = BlockGemmShape::NumWarps; + static constexpr index_t WaveStep = NumWarps * MPerBlock; static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -116,6 +118,213 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { using Base = PipelineImplBase; + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + [[maybe_unused]] CDramBlockWindowTmp& c_dram_block_window_tmp, + index_t num_loop, + void* __restrict__ p_smem_0, + [[maybe_unused]] const EpilogueFunction& epilogue_func) const + { + //static_assert((MPerBlock * num_loop * NumWarps) == Problem::kM, + // "Ping Pong Warps, Tile size and Block size for M dimension does not match."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + index_t warp_id = get_warp_id(); + index_t operation_id = + __builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm + + auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(MPerBlock, 0); + auto c_offset = (warp_id == 0) ? make_array(0, 0) : make_array(MPerBlock, 0); + + auto tensor_views = + Base::GetABLdsTensorViews(static_cast(static_cast(p_smem_0))); + auto& a_lds_block = tensor_views.get(number<0>{}); + auto& b_lds_block = tensor_views.get(number<1>{}); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto a_windows = Base::GetAWindows( + a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset); + auto& a_copy_dram_window = a_windows.get(number<0>{}); + auto& a_copy_lds_window = a_windows.get(number<1>{}); + auto& a_lds_window = a_windows.get(number<2>{}); + + auto b_windows = Base::GetBWindows( + b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + auto& b_copy_dram_window = b_windows.get(number<0>{}); + auto& b_copy_lds_window = b_windows.get(number<1>{}); + auto& b_lds_window = b_windows.get(number<2>{}); + + auto epilogue_dram_window = make_tile_window(c_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(MPerBlock, NPerBlock), + c_dram_block_window_tmp.get_window_origin() + c_offset); + + // Add the offset which is warp specific so that subsequently we can increase it + // with a fixed step size, which is also independent of the warp id. + //c_dram_block_window_tmp += c_offset; + //move_tile_window(c_dram_block_window_tmp, c_offset); + + // DRAM window steps. + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(MPerBlock * NumWarps, 0) + : make_array(0, MPerBlock * NumWarps); + + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, 0); + + using CDramTileWindowStep = typename CDramBlockWindowTmp::BottomTensorIndex; + constexpr CDramTileWindowStep c_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock * NumWarps, 0) : make_array(0, KPerBlock * NumWarps); + + constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using AGemmTile = decltype(make_static_distributed_tensor(AGemmTileDistr)); + using BGemmTile = decltype(make_static_distributed_tensor(BGemmTileDistr)); + AGemmTile a_tile_0, a_tile_1; + BGemmTile b_tile; + + // Register tile for A and B. + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + ABlockTile a_global_load_tile; + BBlockTile b_global_load_tile; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); + auto c_block_tile_1 = block_gemm.MakeCBlockTile(); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_0); + tile_elementwise_inout([](auto& c) { c = 2; }, c_block_tile_1); + + auto BReadOps = [&](){ + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + } + Base::LocalPrefetch(b_tile, b_lds_window); + }; + + // define ping, pong steps here as lambda functions. + auto MemoryOpsStep = [&](auto idx) { + // Memory read half here. + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + } + + if(idx == 0) + { + Base::LocalPrefetch(a_tile_0, a_lds_window); + } + else + { + Base::LocalPrefetch(a_tile_1, a_lds_window); + } + }; + + auto ComputeStep = [&](auto idx) { + if(idx == 0) + { + tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_0); + //block_gemm(c_block_tile_0, a_tile_0, b_tile); + + epilogue_func(epilogue_dram_window, c_block_tile_0); + } + else + { + tile_elementwise_inout([](auto& c) { c = 1; }, c_block_tile_1); + //block_gemm(c_block_tile_1, a_tile_1, b_tile); + + epilogue_func(epilogue_dram_window, c_block_tile_1); + } + }; + + // Read B block tile + BReadOps(); + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + } + + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_compute_steps > 100) + { + block_sync_lds(); + operation_id = (operation_id + 1) % NumWaveGroups; + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + //move_tile_window(c_dram_block_window_tmp, {WaveStep, 0}); + epilogue_dram_window = make_tile_window(epilogue_dram_window.get_bottom_tensor_view(), + make_tuple(MPerBlock, NPerBlock), + epilogue_dram_window.get_window_origin() + c_dram_tile_window_step); + } + else + { + ComputeStep(warp_id); + } + + num_compute_steps -= 1; + } + block_sync_lds(); + + + if(operation_id == 0) + { + ComputeStep(warp_id); + } + + } + template template + typename BElementFunction, + typename EpilogueFunction> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, + const CDramBlockWindowTmp& c_dram_block_window_tmp, index_t num_loop, - void* p_smem_0) const + void* p_smem_0, + const EpilogueFunction& epilogue_func) const { - return PipelineImpl{}.template operator()( + return PipelineImpl{}.template operator()( a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, + c_dram_block_window_tmp, num_loop, - p_smem_0); + p_smem_0, + epilogue_func); } public: - template + template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, + const CDramBlockWindowTmp& c_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0) const + void* __restrict__ p_smem_0, + const EpilogueFunction& epilogue_func) const { - return PipelineImpl{}.template operator()( + return PipelineImpl{}.template operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, + c_dram_block_window_tmp, num_loop, - p_smem_0); - } + p_smem_0, + epilogue_func); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..8e217bcd80 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -212,6 +212,8 @@ struct UniversalGemmPipelineProblem static constexpr bool TransposeC = Traits::TransposeC; static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; + static constexpr index_t PingPongDim = Traits::PingPongDim; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index be777df6a6..92d2a2f93c 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -43,7 +43,8 @@ template + bool Preshuffle_ = 0, + index_t PingPongDim = 0> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -61,6 +62,7 @@ struct TileGemmUniversalTraits static constexpr bool UsePersistentKernel = UsePersistentKernel_; static constexpr index_t NumWaveGroups = NumWaveGroups_; static constexpr bool Preshuffle = Preshuffle_; + static constexpr index_t PingPongDim = PingPongDim_; }; template