From bcf31d9b27d33304658425433d2521f583d39b2c Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Sat, 7 Sep 2024 01:23:32 -0700 Subject: [PATCH] Ck tile gemm example (#1488) * Checkpoint: Finished with the tile example & kernel verification, working on the different matrix layout * Finished the Matrix Layout feature set up. Note: Need to modify the inner block to solve the shuffle problem in the future. * Fix: Clang Format, API fixed from fmha * fix with better naming convention * revert back the pipeline code of fmha * Fixed: Addressed the comments and merge the GEMM shape of GEMM Operator and FMHA Operator to one. * clang format with the reference_gemm file * convert the clang format with the remod.py * Changed the format and variable name of the kernel gemm_shape and partitioner --------- Co-authored-by: thomasning [ROCm/composable_kernel commit: caacd388302dfc4d3d49f51163ff15924d101bbb] --- example/ck_tile/03_gemm/CMakeLists.txt | 2 + example/ck_tile/03_gemm/README.md | 23 ++ example/ck_tile/03_gemm/gemm_basic.cpp | 274 ++++++++++++++++++ example/ck_tile/03_gemm/gemm_basic.hpp | 71 +++++ example/ck_tile/CMakeLists.txt | 1 + .../ck_tile/host/reference/reference_gemm.hpp | 17 +- ...block_fmha_bwd_pipeline_default_policy.hpp | 85 +++--- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 51 ++-- include/ck_tile/ops/gemm.hpp | 2 + .../block/block_gemm_areg_bgmem_creg_v1.hpp | 11 +- ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 4 + .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 176 +++++++++++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 38 +++ ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 18 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 43 ++- ...lock_gemm_pipeline_agmem_bgmem_creg_v2.hpp | 3 +- .../pipeline/block_gemm_pipeline_problem.hpp | 17 +- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 14 +- 18 files changed, 758 insertions(+), 92 deletions(-) create mode 100644 example/ck_tile/03_gemm/CMakeLists.txt create mode 100644 example/ck_tile/03_gemm/README.md 
create mode 100644 example/ck_tile/03_gemm/gemm_basic.cpp create mode 100644 example/ck_tile/03_gemm/gemm_basic.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt new file mode 100644 index 0000000000..03fc9c7eb1 --- /dev/null +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +set(CMAKE_BUILD_TYPE Debug) +add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) \ No newline at end of file diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md new file mode 100644 index 0000000000..00303bf62c --- /dev/null +++ b/example/ck_tile/03_gemm/README.md @@ -0,0 +1,23 @@ +# GEMM Matrix Multiplication + +This folder contains an example of GEMM using the ck_tile tile-programming implementation. Currently, it supports only the basic features of CK Tile GEMM, but it provides placeholders for future support of different GEMM pipelines and GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942... +make tile_example_gemm_basic -j +``` +This will result in an executable `build/bin/tile_example_gemm_basic` + +## example +``` +args: + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp new file mode 100644 index 0000000000..734ba0fe65 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -0,0 +1,274 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "gemm_basic.hpp" +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("b", "1", "batch size") + .insert("m", "1024", "m dimension") + .insert("n", "2048", "n dimension") + .insert("k", "64", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "cpu validation or not") + .insert("e", "1e-5", "Absolute error tolerance") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +{ + // ToDo: This will be modified by the codegen code later. + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. 
+ constexpr bool kPadA = true; + constexpr bool kPadB = true; + constexpr bool kPadC = false; + + constexpr int kBlockPerCu = 1; + + // =============================================== + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + using PipelineProblem = ck_tile:: + BlockGemmPipelineProblem; + // The GemmPipeline should also come from the Codegen. + using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = + ck_tile::GemmKernel; + + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.epsilon, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +template +float invoke_gemm(ck_tile::DeviceMem& a_buf, + ck_tile::DeviceMem& b_buf, + ck_tile::DeviceMem& c_buf, + const ck_tile::ArgParser& arg_parser) +{ + + std::string data_type = arg_parser.get_str("prec"); + + if(data_type != DataTypeTraits::name) + { + std::cerr << "Data type mismatch: expected " << DataTypeTraits::name << ", got " + << data_type << std::endl; + return -1; // Or handle the error appropriately + } + + float epsilon = arg_parser.get_float("e"); + ck_tile::index_t batch_size = arg_parser.get_int("b"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_b = 
arg_parser.get_int("stride_b"); + ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); + + gemm_basic_args args; + args.p_a = a_buf.GetDeviceBuffer(); + args.p_b = b_buf.GetDeviceBuffer(); + args.p_c = c_buf.GetDeviceBuffer(); + args.epsilon = epsilon; + args.kbatch = batch_size; + args.M = M; + args.N = N; + args.K = K; + + // Only set stride_M and stride_N if they are non-zero and not equal to K. + if(stride_a != 0) + { + args.stride_A = stride_a; + } + else + { + args.stride_A = [&]() { + if constexpr(std::is_same_v) + { + return M; + } + else + { + return K; + } + }(); + } + + if(stride_b != 0) + { + args.stride_B = stride_b; + } + else + { + args.stride_B = [&]() { + if constexpr(std::is_same_v) + { + return N; + } + else + { + return K; + } + }(); + } + + if(stride_c != 0) + { + args.stride_C = stride_c; + } + else + { + args.stride_C = [&]() { + if constexpr(std::is_same_v) + { + return M; + } + else + { + return N; + } + }(); + } + + float ave_time = + gemm_calc(args, ck_tile::stream_config{nullptr, true}); + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "The overall perfomance of the GEMM with " + << "[" << data_type << "]" + << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K + << "is: \n"; + std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" + << std::flush; + + return ave_time; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N). 
+ using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor; + using matrix_b_layout = ck_tile::tensor_layout::gemm::RowMajor; + using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor; + + // host verify + std::vector a_dimensions = + (std::is_same_v) + ? std::vector{M, K} + : std::vector{K, M}; + std::vector b_dimensions = + (std::is_same_v) + ? std::vector{N, K} + : std::vector{K, N}; + std::vector c_dimensions = + (std::is_same_v) + ? std::vector{M, N} + : std::vector{N, M}; + + ck_tile::HostTensor a_host(a_dimensions); + ck_tile::HostTensor b_host(b_dimensions); + + ck_tile::HostTensor c_host_ref(c_dimensions); + ck_tile::HostTensor c_host_dev(c_dimensions); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + + invoke_gemm( + a_buf, b_buf, c_buf, arg_parser); + + bool pass = true; + + if(arg_parser.get_bool("v")) + { + // ToDo: Will Add the Element Op (bias) verification in the future. + ck_tile::reference_gemm(a_host, b_host, c_host_ref); + + c_buf.FromDevice(c_host_dev.data()); + + pass = ck_tile::check_err(c_host_dev, c_host_ref); + + std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush; + } + + std::cout << std::endl << std::flush; + + return !pass; +} diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp new file mode 100644 index 0000000000..28afb194c9 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -0,0 +1,71 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; // type convert + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +struct gemm_basic_args +{ + const void* p_a; + const void* p_b; + void* p_c; + float epsilon; + ck_tile::index_t kbatch; + ck_tile::index_t M; + ck_tile::index_t N; + ck_tile::index_t K; + ck_tile::index_t stride_A; + ck_tile::index_t stride_B; + ck_tile::index_t stride_C; +}; + +// host API +float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 995d193f10..3b4d1ca8be 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -4,3 +4,4 @@ include_directories(AFTER add_subdirectory(01_fmha) add_subdirectory(02_layernorm2d) +add_subdirectory(03_gemm) diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index efdaa23f3f..df2d719971 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -13,6 +13,9 @@ 
template @@ -24,7 +27,12 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, const ACCElementOp& acc_element_op = {}) { const int N = b_n_k.mDesc.get_lengths()[0]; - const int K = b_n_k.mDesc.get_lengths()[1]; + const int K = (std::is_same_v) + ? a_m_k.mDesc.get_lengths()[1] + : a_m_k.mDesc.get_lengths()[0]; + const int M = (std::is_same_v) + ? a_m_k.mDesc.get_lengths()[0] + : a_m_k.mDesc.get_lengths()[1]; auto f = [&](auto m) { for(int n = 0; n < N; ++n) @@ -33,7 +41,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, for(int k = 0; k < K; ++k) { - ADataType v_a = a_element_op(a_m_k(m, k)); + ADataType v_a = (std::is_same_v) + ? a_element_op(a_m_k(m, k)) + : a_element_op(a_m_k(k, m)); BDataType v_b = b_element_op(b_n_k(n, k)); v_acc += ck_tile::type_convert(v_a) * @@ -44,7 +54,6 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, } }; - make_ParallelTensorFunctor(f, - c_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 4143c34ff8..e6a71f210e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::QDataType, @@ -57,14 +58,15 @@ 
struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::GemmDataType, + typename Problem::OGradDataType, + typename Problem::AccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::OGradDataType, + typename Problem::VDataType, + typename Problem::AccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm2BlockWarps, + typename Problem::BlockFmhaShape::Gemm2WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher< typename Problem::OGradDataType, @@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::GemmDataType, + typename Problem::QDataType, + typename Problem::AccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm3BlockWarps, + typename Problem::BlockFmhaShape::Gemm3WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::GemmDataType, + typename Problem::KDataType, + typename Problem::AccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm4BlockWarps, + typename Problem::BlockFmhaShape::Gemm4WarpTile>>; using WarpGemm = WarpGemmMfmaDispatcher template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using 
BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -198,14 +199,15 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::QDataType, + typename Problem::KDataType, + typename Problem::SaccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename Problem::BlockFmhaShape::Gemm0WarpTile>>; constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && @@ -952,14 +954,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - using BlockGemmProblem = - BlockGemmPipelineProblem>; + using BlockGemmProblem = BlockGemmPipelineProblem< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + TileGemmShape, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; auto warp_gemm = [&]() { if constexpr(std::is_same_v && diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index dd313c5480..e9005462b0 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -21,6 +21,8 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include 
"ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index f097790ae6..8d9e24638a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" namespace ck_tile { @@ -27,9 +28,9 @@ struct BlockGemmARegBGmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation - using BlockGemmARegBSmemCRegImpl = BlockGemmARegBSmemCRegV1< + using BlockGemmARegBGmemCRegImpl = BlockGemmARegBGmemCRegV1< BlockGemmProblem, - BlockGemmARegBSmemCRegV1DefaultPolicy>; + BlockGemmARegBGmemCRegV1DefaultPolicy>; CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { @@ -82,7 +83,7 @@ struct BlockGemmARegBGmemCRegV1 block_sync_lds(); // block GEMM - BlockGemmARegBSmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window); + BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window); } // C = A * B @@ -128,7 +129,7 @@ struct BlockGemmARegBGmemCRegV1 block_sync_lds(); // block GEMM - return BlockGemmARegBSmemCRegImpl{}(a_block_tensor, b_block_smem_window); + return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window); } }; diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp 
b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index f798d6e815..8dd1d1ec28 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp new file mode 100644 index 0000000000..8cdf9b1005 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include + +#include + +namespace ck_tile { + +template +struct GemmKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + static constexpr index_t KernelBlockSize = GemmPipeline::KernelBlockSize; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CAccDataType = remove_cvref_t; + using CODataType = remove_cvref_t; + + __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) + { + return TilePartitioner::GridSize(M_size, N_size, Batch_size); + } + + __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + struct GemmCommonKargs + { + const void* a_ptr; + const void* b_ptr; + void* c_ptr; + + float epsilon; + + ck_tile::index_t M; + ck_tile::index_t N; + ck_tile::index_t K; + ck_tile::index_t stride_A; + ck_tile::index_t 
stride_B; + ck_tile::index_t stride_C; + }; + + CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + float epsilon, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) + { + return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C}; + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const + { + const index_t i_m = TilePartitioner::iM; + const index_t i_n = TilePartitioner::iN; + // options + const ADataType* a_start = static_cast(kargs.a_ptr); + const BDataType* b_start = static_cast(kargs.b_ptr); + // Convert pointers to tensor views + auto a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_start, + make_tuple(kargs.M, kargs.K), + make_tuple(1, kargs.stride_A), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_start, + make_tuple(kargs.M, kargs.K), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + auto b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + b_start, + make_tuple(kargs.N, kargs.K), + make_tuple(1, kargs.stride_B), + number{}, + number<1>{}); + } + else + { // Default NK layout + return make_naive_tensor_view( + b_start, + make_tuple(kargs.N, kargs.K), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + }(); + + auto ABlockWindow = make_tile_window( + a_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + + auto BBlockWindow = make_tile_window( + b_tensor_view, + make_tuple(number{}, number{}), + {i_n, 0}); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + const index_t num_loop = (kargs.K + 
TilePartitioner::kK - 1) / TilePartitioner::kK; + + auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr); + + CODataType* c_start = static_cast(kargs.c_ptr); + + auto c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_start, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_start, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + }(); + + auto CBlockWindow = make_tile_window( + c_tensor_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + // epilogue. + EpiloguePipeline{}(CBlockWindow, acc); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp new file mode 100644 index 0000000000..038d09ea35 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +template +struct GemmTilePartitioner +{ + using BlockGemmShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM = BlockGemmShape::kM; + static constexpr ck_tile::index_t kN = BlockGemmShape::kN; + static constexpr ck_tile::index_t kK = BlockGemmShape::kK; + + const index_t iM = __builtin_amdgcn_readfirstlane(i_tile_m * kM); + const index_t iN = __builtin_amdgcn_readfirstlane(i_tile_n * kN); + + CK_TILE_HOST static constexpr auto + GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size) + { + ck_tile::index_t GridDimX = (M + kM - 1) / kM; + ck_tile::index_t GridDimY = (N + kN - 1) / kN; + ck_tile::index_t GridDimZ = batch_size; + return dim3(GridDimX, GridDimY, GridDimZ); + } + + CK_TILE_DEVICE auto operator()() + { + const index_t i_GridDimX = blockIdx.x; + const index_t i_GridDimY = blockIdx.y; + const index_t i_GridDimZ = blockIdx.z; + return ck_tile::make_tuple(i_GridDimX, i_GridDimY, i_GridDimZ); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp index 29b03e2828..a90178ddb1 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" namespace ck_tile { @@ -18,12 +19,16 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t KernelBlockSize = Problem::KernelBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static 
constexpr index_t kKPerBlock = BlockGemmShape::kK; + static constexpr index_t AlignmentA = Problem::AlignmentA; + static constexpr index_t AlignmentB = Problem::AlignmentB; + static constexpr index_t AlignmentC = Problem::AlignmentC; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return ck_tile::integer_divide_ceil( @@ -35,6 +40,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); } + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + template 0) { // global read i + 1 a_block_tile = load_tile(a_copy_dram_window); @@ -167,8 +176,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 store_tile(b_copy_lds_window, b_block_tile_tmp); iCounter--; - - } while(iCounter > 0); + } // tail { diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index f706900013..c7f292d2b5 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy return b_lds_block_desc; } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + 
constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } #elif 1 // fake XOR template @@ -168,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy { using ADataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t KernelBlockSize = Problem::KernelBlockSize; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; @@ -177,7 +204,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t K0 = kKPerBlock / K1; constexpr index_t M2 = get_warp_size() / K0; #if 1 // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M1 = KernelBlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( @@ -188,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy sequence<1, 2>, sequence<0, 1>>{}); #else // coalesce reading for each warps - constexpr index_t M0 = kBlockSize / get_warp_size(); + constexpr index_t M0 = KernelBlockSize / get_warp_size(); constexpr index_t M1 = kMPerBlock / (M2 * M0); return make_static_tile_distribution( @@ -206,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy { using BDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t KernelBlockSize = Problem::KernelBlockSize; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; @@ -215,7 +244,9 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N2 = get_warp_size() / K0; #if 1 // coalesce reading for each blocks - 
constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N1 = KernelBlockSize / get_warp_size(); + static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t N0 = kNPerBlock / (N2 * N1); return make_static_tile_distribution( @@ -226,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy sequence<1, 2>, sequence<0, 1>>{}); #else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t N0 = KernelBlockSize / get_warp_size(); constexpr index_t N1 = kNPerBlock / (N2 * N0); return make_static_tile_distribution( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp index ef0611dd60..deb9b07f16 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" namespace ck_tile { @@ -18,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t KernelBlockSize = Problem::KernelBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp index 62165ebce2..dda6022dc8 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp @@ -5,13 +5,17 @@ 
#include "ck_tile/core.hpp" +#define VectorLoadSize 16 + namespace ck_tile { template + typename BlockGemmShape_, + bool kPadA_ = false, + bool kPadB_ = false, + bool kPadC_ = false> struct BlockGemmPipelineProblem { using ADataType = remove_cvref_t; @@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + static constexpr bool kPadA = kPadA_; + static constexpr bool kPadB = kPadB_; + static constexpr bool kPadC = kPadC_; + + static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1; + static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1; + static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index f3c4d8bf67..2522abe5ed 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -7,12 +7,18 @@ namespace ck_tile { -template +template struct TileGemmShape { - static constexpr index_t kM = kMPerTile; - static constexpr index_t kN = kNPerTile; - static constexpr index_t kK = kKPerTile; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(number<0>{}); + static constexpr index_t kN = BlockTile::at(number<1>{}); + static constexpr index_t kK = BlockTile::at(number<2>{}); }; } // namespace ck_tile