From 99640365d8dfb9065cfffea8f3a109ba4a9f8fef Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 10 Oct 2024 03:02:22 -0700 Subject: [PATCH] Ck tile gemm cshuffle & CK Tile GEMM restructure (#1535) * ake the cshuffle compilable * modify Mhe reference on gpu and cpu. Correaccess of cshuffle * fix the cpu reference code * Complete the in tile shuffle logic * restructure the kernel template input * change the naming pattern of ck_tile gemm pipeline * Re-format files using remod.py * Solve the fmha conflict with gemm * Comment Addressed from Carlus --------- Co-authored-by: Po Yen, Chen [ROCm/composable_kernel commit: 6f27bc987248633255cc400437bd017dca70cf1e] --- example/ck_tile/03_gemm/gemm_basic.cpp | 55 ++++-- .../ck_tile/core/container/thread_buffer.hpp | 2 +- .../ck_tile/host/reference/reference_gemm.hpp | 47 ++++- include/ck_tile/ops/epilogue.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 171 ++++++++++++++++++ ...block_fmha_bwd_pipeline_default_policy.hpp | 133 ++++++++------ ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 81 +++++---- include/ck_tile/ops/gemm.hpp | 11 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 15 +- ... => gemm_pipeline_agmem_bgmem_creg_v1.hpp} | 10 +- ...ne_agmem_bgmem_creg_v1_default_policy.hpp} | 4 +- ... => gemm_pipeline_agmem_bgmem_creg_v2.hpp} | 6 +- ...ne_agmem_bgmem_creg_v2_default_policy.hpp} | 9 +- ..._problem.hpp => gemm_pipeline_problem.hpp} | 17 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 27 +++ 15 files changed, 447 insertions(+), 142 deletions(-) create mode 100644 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1.hpp => gemm_pipeline_agmem_bgmem_creg_v1.hpp} (95%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp} (99%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2.hpp => gemm_pipeline_agmem_bgmem_creg_v2.hpp} (97%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp => gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp} (57%) rename include/ck_tile/ops/gemm/pipeline/{block_gemm_pipeline_problem.hpp => gemm_pipeline_problem.hpp} (65%) create mode 100644 include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 9f790f6acb..e3c8d72590 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -41,18 +41,39 @@ template ; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; + + // Whether doing the CShuffle (transpose before the global memory), depending on the output + // layout. + constexpr bool CShuffleEpilogue = + std::is_same_v; + + using GemmEpilogue = std::conditional_t< + CShuffleEpilogue, + ck_tile::CShuffleEpilogue>, + ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = - ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKargs(args.p_a, args.p_b, @@ -255,15 +276,13 @@ int main(int argc, char* argv[]) ck_tile::sequence, ck_tile::sequence>; - using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem; + using CodegenGemmTraits = ck_tile:: + TileGemmTraits; - using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm c_host_gpu_ref(c_dimensions); ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); - ck_tile::reference_gemm_gpu( + ck_tile::reference_gemm_gpu( a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); c_buf.FromDevice(c_host_gpu_ref.data()); diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index a7dad5233b..279a48acb3 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -58,7 +58,7 @@ struct thread_buffer { template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } - + template ::value, bool>::type = false> CK_TILE_HOST_DEVICE constexpr auto _get_as() const diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index a0ddd02d9e..a496c91e00 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const int N = b_n_k.mDesc.get_lengths()[0]; + const int N = (std::is_same_v) + ? b_n_k.mDesc.get_lengths()[0] + : b_n_k.mDesc.get_lengths()[1]; const int K = (std::is_same_v) ? a_m_k.mDesc.get_lengths()[1] : a_m_k.mDesc.get_lengths()[0]; @@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, ADataType v_a = (std::is_same_v) ? a_element_op(a_m_k(m, k)) : a_element_op(a_m_k(k, m)); - BDataType v_b = b_element_op(b_n_k(n, k)); + BDataType v_b = (std::is_same_v) + ? b_element_op(b_n_k(n, k)) + : b_element_op(b_n_k(k, n)); v_acc += ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); } - c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + CDataType& c_ref = (std::is_same_v) + ? c_m_n(m, n) + : c_m_n(n, m); + c_ref = ck_tile::type_convert(acc_element_op(v_acc)); } }; make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); } -template +template __global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C, @@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, if(row < M && col < N) { AccDataType acc = 0.0; - for(int k = 0; k < K; ++k) { - acc += static_cast(A[row * strideA + k]) * - static_cast(B[col * strideB + k]); + // Adjust indexing based on matrix layout + int a_index = (std::is_same_v) + ? row * strideA + k + : k * strideA + row; + int b_index = (std::is_same_v) + ? col * strideB + k + : k * strideB + col; + acc += static_cast(A[a_index]) * static_cast(B[b_index]); } - C[row * strideC + col] = acc; // Store as AccDataType + int c_index = (std::is_same_v) + ? row * strideC + col + : col * strideC + row; + C[c_index] = acc; } } -template +template void reference_gemm_gpu(DeviceMem& a_device, DeviceMem& b_device, DeviceMem& c_device, @@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, int numThreadsPerBlock = 256; // Common choice for threads per block int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; - naive_gemm_kernel + naive_gemm_kernel <<>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); errC = hipMemcpy( c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 388f52c898..a98f60b364 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -3,5 +3,6 @@ #pragma once +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp new file mode 100644 index 0000000000..9625b137bd --- /dev/null +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +#define CK_TILE_MAX_RANK 5 + +namespace ck_tile { + +// this epilogue aiming to store a matrix with different layout from the shared memory to the global +// memory. +template +struct CShuffleEpilogueProblem +{ + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kTilePermute = kTilePermute_; + static constexpr index_t kRank = kRank_; + static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4}; + static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = { + TileSize0, TileSize1, TileSize2, TileSize3, TileSize4}; +}; + +template +struct CShuffleEpilogue +{ + using Problem = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + const index_t* kPerm = Problem::kPerm; + static constexpr bool kTilePermute = Problem::kTilePermute; + static constexpr index_t kRank = Problem::kRank; + const index_t* tile_sizes = Problem::tile_sizes; + + // No additional shared memory needed + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + + template + CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) + { + using DataType = typename OAccTile::DataType; + + // Get thread buffer + auto& thread_buf = o_acc_tile.get_thread_buffer(); + + // Create a temporary buffer to hold the permuted data + thread_buffer permuted_thread_buf; + + // Get the lengths of each dimension + auto thread_tensor_lengths = o_acc_tile.get_lengths(); + + // Total number of elements + index_t total_elements = OAccTile::kThreadElementSpaceSize; + + // Iterate over all elements + for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) + { + // Convert linear index to multi-dimensional indices + array indices; + index_t remaining = linear_idx; + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + indices(rev_i) = remaining % thread_tensor_lengths.get(number{}); + remaining /= thread_tensor_lengths.get(number{}); + }); + + // Apply the permutation + array permuted_indices; + static_for<0, kRank, 1>{}( + [&](auto i) { permuted_indices(i) = indices.get(number{}); }); + + // Compute offsets + index_t dst_offset = 0; + index_t stride = 1; + + static_for<0, kRank, 1>{}([&](auto i) { + constexpr auto rev_i = kRank - 1 - i; + dst_offset += permuted_indices[rev_i] * stride; + stride *= thread_tensor_lengths.get(number{}); + }); + + // Move the data + permuted_thread_buf(dst_offset) = thread_buf[linear_idx]; + } + + // Copy the permuted data back to the original thread buffer + for(index_t i = 0; i < total_elements; ++i) + { + thread_buf.set_as(i, permuted_thread_buf.get(i)); + } + } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) + { + const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); + + // Compute the tile coordinates by dividing the window origin by the tile sizes + index_t tile_coords[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + tile_coords[i] = current_window_origin[i] / tile_sizes[i]; + // printf("The tile_coord is: %d", tile_coords[i]); + } + + // Apply the permutation to the tile coordinates + index_t permuted_tile_coords[CK_TILE_MAX_RANK]; + for(index_t i = 0; i < kRank; ++i) + { + permuted_tile_coords[i] = tile_coords[kPerm[i]]; + // printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]); + } + + // Compute the permuted window origin + index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0}; + for(index_t i = 0; i < kRank; ++i) + { + permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i]; + // printf("The new permuted_window_origin is: %d", permuted_window_origin[i]); + } + + typename ODramWindowTmp::BottomTensorIndex step = {}; + for(index_t i = 0; i < kRank; ++i) + { + step[i] = permuted_window_origin[i] - current_window_origin[i]; + } + + // Move the window + move_tile_window(o_dram_window_tmp, step); + + // Permute the data within the tile if necessary + if constexpr(kTilePermute) + { + permute_tile_data(o_acc_tile); + } + + // Store the tile data to the permuted location + if constexpr(kPadM || kPadN) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 8647a7d25a..e1f05d39db 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" @@ -25,15 +26,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -52,21 +59,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::OGradDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::OGradDataType, - typename Problem::VDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm2BlockWarps, - typename Problem::BlockFmhaShape::Gemm2WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -117,21 +136,27 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::BlockFmhaShape::Gemm2BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::QDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm3BlockWarps, - typename Problem::BlockFmhaShape::Gemm3WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } template CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::GemmDataType, - typename Problem::KDataType, - typename Problem::AccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm4BlockWarps, - typename Problem::BlockFmhaShape::Gemm4WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>, + TileGemmTraits>; using WarpGemm = WarpGemmMfmaDispatcher; - return BlockGemmARegBRegCRegV1{}; + return BlockGemmARegBRegCRegV1{}; } // these are for global load diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index ae9e320f67..4ea0c4c9f2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -5,8 +5,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" @@ -75,15 +76,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -116,7 +123,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; @@ -199,15 +206,21 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::QDataType, - typename Problem::KDataType, - typename Problem::SaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm0BlockWarps, - typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>, + TileGemmTraits>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -240,7 +253,7 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>; - return BlockGemmASmemBSmemCRegV1{}; + return BlockGemmASmemBSmemCRegV1{}; } }; @@ -954,15 +967,21 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - using BlockGemmProblem = BlockGemmPipelineProblem< - typename Problem::PDataType, - typename Problem::VDataType, - typename Problem::OaccDataType, - TileGemmShape, - typename Problem::BlockFmhaShape::Gemm1BlockWarps, - typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + using GemmProblem = + GemmPipelineProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>, + TileGemmTraits>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && @@ -996,7 +1015,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2{}; } }; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index e9005462b0..dc5983e4d1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,12 +23,13 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e24d7f9ea0..48329c8ba5 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -11,20 +11,12 @@ namespace ck_tile { -template +template struct GemmKernel { using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; using ADataType = remove_cvref_t; @@ -32,6 +24,10 @@ struct GemmKernel using CAccDataType = remove_cvref_t; using CODataType = remove_cvref_t; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) { return TilePartitioner::GridSize(M_size, N_size, Batch_size); @@ -184,6 +180,7 @@ struct GemmKernel c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); + EpiloguePipeline{}(CBlockWindow_pad, acc); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp similarity index 95% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index bec8a204cc..5ed7d036ea 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,15 +4,15 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV1 +template +struct GemmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadC = Problem::kPadC; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return ck_tile::integer_divide_ceil( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp similarity index 99% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 3048adad67..8639f00fbb 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -7,9 +7,9 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV1 +// Default policy for GemmPipelineAGmemBGmemCRegV1 // Default policy class should not be templated, put template on member functions instead -struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy +struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { #if 0 // 2d diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp similarity index 97% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index ab5fe79114..bff7fc0a0e 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,15 +4,15 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template -struct BlockGemmPipelineAGmemBGmemCRegV2 +template +struct GemmPipelineAGmemBGmemCRegV2 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp similarity index 57% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp index 0596408501..7dad55d6b9 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp @@ -7,12 +7,11 @@ namespace ck_tile { -// Default policy for BlockGemmPipelineAGmemBGmemCRegV2 +// Default policy for GemmPipelineAGmemBGmemCRegV2 // Default policy class should not be templated, put template on member functions instead // NOTE: policy should be binded to its corresponding operation. It's just a coincidence that -// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as -// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy -using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy = - BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy; +// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as +// GemmPipelineAGmemBGmemCRegV1DefaultPolicy +using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp similarity index 65% rename from include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp rename to include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 8dfba08ad7..d7b3b24a4a 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -13,20 +13,23 @@ template -struct BlockGemmPipelineProblem + typename TileGemmTraits_> +struct GemmPipelineProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using GemmTraits = remove_cvref_t; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr bool kPadA = GemmTraits::kPadA; + static constexpr bool kPadB = GemmTraits::kPadB; + static constexpr bool kPadC = GemmTraits::kPadC; + + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp new file mode 100644 index 0000000000..98da1510c7 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileGemmTraits +{ + static constexpr bool kPadA = kPadA_; + static constexpr bool kPadB = kPadB_; + static constexpr bool kPadC = kPadC_; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + using LayoutC = LayoutC_; +}; + +} // namespace ck_tile