diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 03fc9c7eb1..8ae46cadc6 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ -set(CMAKE_BUILD_TYPE Debug) -add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) \ No newline at end of file +add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 569afed256..09427217c5 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_basic.hpp" #include #include @@ -10,51 +9,48 @@ #include #include -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "1024", "m dimension") - .insert("n", "2048", "n dimension") - .insert("k", "64", "k dimension") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("e", "1e-5", "Absolute error tolerance") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "10", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_basic.hpp" - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template +template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadA = true; constexpr bool kPadB = true; + constexpr bool kPadC = true; constexpr bool kTilePermute = false; + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; constexpr int kBlockPerCu = 1; - using TilePartitioner = ck_tile::GemmTilePartitioner; + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; - // The rank and permutation will also be generate out by the CodeGen part. - constexpr ck_tile::index_t kOutputRank = 2; + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; // Whether doing the CShuffle (transpose before the global memory), depending on the output // layout. constexpr bool CShuffleEpilogue = - std::is_same_v; + std::is_same_v; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTilePartitioner; using GemmEpilogue = std::conditional_t< CShuffleEpilogue, @@ -70,14 +66,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) TilePartitioner::kN>>, ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>>; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKargs(args.p_a, args.p_b, args.p_c, - args.epsilon, args.M, args.N, args.K, @@ -88,299 +91,20 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); constexpr dim3 blocks = Kernel::BlockSize(); + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } -template -float invoke_gemm(ck_tile::DeviceMem& a_buf, - ck_tile::DeviceMem& b_buf, - ck_tile::DeviceMem& c_buf, - const ck_tile::ArgParser& arg_parser) -{ +#include "run_gemm_example.inc" - std::string data_type = arg_parser.get_str("prec"); - - if(data_type != DataTypeTraits::name) - { - std::cerr << "Data type mismatch: expected " << DataTypeTraits::name << ", got " - << data_type << std::endl; - return -1; // Or handle the error appropriately - } - - float epsilon = arg_parser.get_float("e"); - ck_tile::index_t batch_size = arg_parser.get_int("b"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); - - gemm_basic_args args; - args.p_a = a_buf.GetDeviceBuffer(); - args.p_b = b_buf.GetDeviceBuffer(); - args.p_c = c_buf.GetDeviceBuffer(); - args.epsilon = epsilon; - args.kbatch = batch_size; - args.M = M; - args.N = N; - args.K = K; - - // Only set stride_M and stride_N if they are non-zero and not equal to K. - if(stride_a != 0) - { - args.stride_A = stride_a; - } - else - { - args.stride_A = [&]() { - if constexpr(std::is_same_v) - { - return M; - } - else - { - return K; - } - }(); - } - - if(stride_b != 0) - { - args.stride_B = stride_b; - } - else - { - args.stride_B = [&]() { - if constexpr(std::is_same_v) - { - return N; - } - else - { - return K; - } - }(); - } - - if(stride_c != 0) - { - args.stride_C = stride_c; - } - else - { - args.stride_C = [&]() { - if constexpr(std::is_same_v) - { - return M; - } - else - { - return N; - } - }(); - } - - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true}); - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "The overall perfomance of the GEMM with " - << "[" << data_type << "]" - << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K - << " is: \n"; - std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n" - << std::flush; - - return ave_time; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N). - using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor; - using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor; - using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor; - - // host verify - std::vector a_dimensions = - (std::is_same_v) - ? std::vector{M, K} - : std::vector{K, M}; - std::vector b_dimensions = - (std::is_same_v) - ? std::vector{N, K} - : std::vector{K, N}; - std::vector c_dimensions = - (std::is_same_v) - ? std::vector{M, N} - : std::vector{N, M}; - - ck_tile::HostTensor a_host(a_dimensions); - ck_tile::HostTensor b_host(b_dimensions); - - ck_tile::HostTensor c_host_ref(c_dimensions); - ck_tile::HostTensor c_host_dev(c_dimensions); - - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); - - ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); - - a_buf.ToDevice(a_host.data()); - b_buf.ToDevice(b_host.data()); - - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; - - // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using CodegenGemmTraits = ck_tile:: - TileGemmTraits; - - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - - using CodegenGemmPolicy = ck_tile:: - UniversalGemmPipelineAgBgCrPolicy; - - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1; - - invoke_gemm(a_buf, b_buf, c_buf, arg_parser); - - c_buf.FromDevice(c_host_dev.data()); - - bool pass_cpu = true; - - if(arg_parser.get_int("v") == 1) - { - // ToDo: Will Add the Element Op (bias) verification in the future. - ck_tile::reference_gemm(a_host, b_host, c_host_ref); - - pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref); - - std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail") - << std::flush; - } - - bool pass_gpu = true; - - if(arg_parser.get_int("v") == 2) - { - ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); - - if(stride_a == 0) - { - if constexpr(std::is_same_v) - { - stride_a = M; - } - else - { - stride_a = K; - } - } - - if(stride_b == 0) - { - if constexpr(std::is_same_v) - { - stride_b = N; - } - else - { - stride_b = K; - } - } - - if(stride_c == 0) - { - if constexpr(std::is_same_v) - { - stride_c = M; - } - else - { - stride_c = N; - } - } - - ck_tile::HostTensor c_host_gpu_ref(c_dimensions); - ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); - - ck_tile::reference_gemm_gpu( - a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); - - c_buf.FromDevice(c_host_gpu_ref.data()); - - pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref); - - std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail") - << std::flush; - } - - std::cout << std::endl << std::flush; - - return !pass_gpu; -} +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index ce2e0f706d..23e99bc2a8 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -4,12 +4,10 @@ #pragma once +#include + #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host.hpp" -#include template struct GemmBasicTypeConfig; @@ -20,7 +18,7 @@ struct GemmBasicTypeConfig using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; - using CDataType = ck_tile::half_t; // type convert + using CDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; @@ -58,7 +56,6 @@ struct gemm_basic_args const void* p_a; const void* p_b; void* p_c; - float epsilon; ck_tile::index_t kbatch; ck_tile::index_t M; ck_tile::index_t N; @@ -68,5 +65,28 @@ struct gemm_basic_args ck_tile::index_t stride_C; }; +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("b", "1", "batch size") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + // host API float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp new file mode 100644 index 0000000000..2ee0395e47 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_basic.hpp" + +template +float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +{ + // ToDo: This will be modified by the codegen code later. + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadA = true; + constexpr bool kPadB = true; + constexpr bool kPadC = true; + + constexpr int kBlockPerCu = 1; + + // =============================================== + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + + using Traits = ck_tile::TileGemmTraits; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc new file mode 100644 index 0000000000..8db131738b --- /dev/null +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + gemm_basic_args args; + args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); + args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); + args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); + args.kbatch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm{MemBoundPipeline}"}; + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t batch_size = arg_parser.get_int("b"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, a_layout); + stride_B = f_get_default_stride(K, N, stride_B, b_layout); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout)); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout)); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + // TODO: add different init types + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_size, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::HostTensor c_m_n_gpu_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ck_tile::reference_gemm_gpu( + a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index fa4b8d3cc4..2c423831e1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -56,6 +56,7 @@ #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" +#include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 06b5a8da0b..f150fc54ca 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -46,6 +46,31 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear{}, bool_constant{}); } +template +CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, + const tile_window_with_static_distribution& tile_window, + bool_constant = {}) +{ + return tile_window.load(dst_tile, bool_constant{}); +} + +/** + * @brief Loads a tile of data using inline assembly. + * + * @note Bare in mind that loading data this way, you have to manually initialize your + * thread buffer and synchronize load afterwards in order to make sure it's done before + * using loaded data from registers + * @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp` + * @see `buffer_load_fence()` + */ template = {}, bool_constant = {}) const { - using Traits = load_store_traits; + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, bool_constant{}); + return dst_tensor; + } + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); - // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { /// TODO: use structure binding (to be captured later) if compiled in C++20 @@ -353,8 +360,6 @@ struct tile_window_with_static_distribution } }); }); - - return dst_tensor; } template + +namespace ck_tile { +namespace literals { +// [P0330] Literal Suffix for (signed) size_t (C++23) +// ref: https://wg21.link/p0330r8 +inline constexpr std::size_t operator""_uz(unsigned long long size) +{ + return static_cast(size); +} + +inline constexpr std::size_t operator""_zu(unsigned long long size) +{ + return static_cast(size); +} +} // namespace literals +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index a496c91e00..dbdef0e9c7 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1,12 +1,13 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include +#include + #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include namespace ck_tile { @@ -14,55 +15,36 @@ template CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, - const HostTensor& b_n_k, + const HostTensor& b_k_n, HostTensor& c_m_n, const AElementOp& a_element_op = {}, const BElementOp& b_element_op = {}, const ACCElementOp& acc_element_op = {}) { - const int N = (std::is_same_v) - ? b_n_k.mDesc.get_lengths()[0] - : b_n_k.mDesc.get_lengths()[1]; - const int K = (std::is_same_v) - ? a_m_k.mDesc.get_lengths()[1] - : a_m_k.mDesc.get_lengths()[0]; - const int M = (std::is_same_v) - ? a_m_k.mDesc.get_lengths()[0] - : a_m_k.mDesc.get_lengths()[1]; + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); - auto f = [&](auto m) { - for(int n = 0; n < N; ++n) + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + + for(std::size_t k = 0; k < K; ++k) { - AccDataType v_acc = 0; + ADataType v_a = a_element_op(a_m_k(m, k)); + BDataType v_b = b_element_op(b_k_n(k, n)); - for(int k = 0; k < K; ++k) - { - ADataType v_a = (std::is_same_v) - ? a_element_op(a_m_k(m, k)) - : a_element_op(a_m_k(k, m)); - BDataType v_b = (std::is_same_v) - ? b_element_op(b_n_k(n, k)) - : b_element_op(b_n_k(k, n)); - - v_acc += ck_tile::type_convert(v_a) * - ck_tile::type_convert(v_b); - } - - CDataType& c_ref = (std::is_same_v) - ? c_m_n(m, n) - : c_m_n(n, m); - c_ref = ck_tile::type_convert(acc_element_op(v_acc)); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); }; - make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } template , BlockGemmARegBGmemCRegV1DefaultPolicy>; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { return sizeof(BDataType) * Policy::template MakeBSmemBlockDescriptor().get_element_space_size(); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index dc0b411352..d6fee879b1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1 static constexpr index_t kBlockSize = Problem::kBlockSize; // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindowTmp& a_block_window_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + const ABlockWindow& a_block_window, + const BBlockWindow& b_block_window) const { - static_assert(std::is_same_v && - std::is_same_v && + static_assert(std::is_same_v && + std::is_same_v && std::is_same_v, "wrong!"); - constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && KPerBlock == BlockGemmShape::kK, @@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1 // construct A-warp-window auto a_warp_window_tmp = make_tile_window( - a_block_window_tmp.get_bottom_tensor_view(), + a_block_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), - a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); #if 0 // FIXME: using array will cause register spill @@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1 // construct B-warp-window auto b_warp_window_tmp = make_tile_window( - b_block_window_tmp.get_bottom_tensor_view(), + b_block_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), - b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); #if 0 // FIXME: using array will cause register spill @@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1 } // C = A * B - template + template CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const BBlockWindowTmp& b_block_window_tmp) const + const BBlockWindow& b_block_window) const { auto c_block_tensor = MakeCBlockTile(); - operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window); return c_block_tensor; } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 48329c8ba5..1671ddad3e 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -3,11 +3,12 @@ #pragma once +#include +#include + #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" -#include - -#include +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -17,20 +18,19 @@ struct GemmKernel using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CAccDataType = remove_cvref_t; - using CODataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + // using CAccDataType = remove_cvref_t; + using CDataType = remove_cvref_t; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; - - __host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size) + __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { - return TilePartitioner::GridSize(M_size, N_size, Batch_size); + return TilePartitioner::GridSize(M, N, KBatch); } __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } @@ -40,34 +40,30 @@ struct GemmKernel const void* a_ptr; const void* b_ptr; void* c_ptr; - - float epsilon; - - ck_tile::index_t M; - ck_tile::index_t N; - ck_tile::index_t K; - ck_tile::index_t stride_A; - ck_tile::index_t stride_B; - ck_tile::index_t stride_C; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; }; CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr, const void* b_ptr, void* c_ptr, - float epsilon, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C) + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t stride_C) { - return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C}; + return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}; } - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const @@ -78,43 +74,43 @@ struct GemmKernel const BDataType* b_start = static_cast(kargs.b_ptr); // Convert pointers to tensor views auto a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_start, - make_tuple(kargs.M, kargs.K), - make_tuple(1, kargs.stride_A), - number{}, - number<1>{}); - } - else + if constexpr(std::is_same_v) { return make_naive_tensor_view( a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1), - number{}, + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_start, + make_tuple(kargs.M, kargs.K), + make_tuple(1, kargs.stride_A), + number<1>{}, number<1>{}); } }(); auto b_tensor_view = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return make_naive_tensor_view( b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B), - number{}, + number<1>{}, number<1>{}); } else - { // Default NK layout + { return make_naive_tensor_view( b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1), - number{}, + number{}, number<1>{}); } }(); @@ -122,10 +118,12 @@ struct GemmKernel auto a_pad_view = pad_tensor_view( a_tensor_view, make_tuple(number{}, number{}), - sequence < 0, - GemmPipeline::kPadA ? 1 : 0 > {}); + // somehow clang-format is splitting below line into multiple. + // clang-format off + sequence{}); + // clang-format on - auto ABlockWindow = make_tile_window( + auto a_block_window = make_tile_window( a_pad_view, make_tuple(number{}, number{}), {i_m, 0}); @@ -133,10 +131,11 @@ struct GemmKernel auto b_pad_view = pad_tensor_view( b_tensor_view, make_tuple(number{}, number{}), - sequence < 0, - GemmPipeline::kPadB ? 1 : 0 > {}); + // clang-format off + sequence{}); + // clang-format on - auto BBlockWindow = make_tile_window( + auto b_block_window = make_tile_window( b_pad_view, make_tuple(number{}, number{}), {i_n, 0}); @@ -144,20 +143,21 @@ struct GemmKernel // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK; + const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); - auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr); - - CODataType* c_start = static_cast(kargs.c_ptr); + // Run GEMM cooperatively by whole wokrgroup. + auto c_block_tile = + GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); + CDataType* c_start = static_cast(kargs.c_ptr); auto c_tensor_view = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return make_naive_tensor_view( c_start, make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number{}, + make_tuple(kargs.stride_C, 1), + number{}, number<1>{}); } else @@ -165,8 +165,8 @@ struct GemmKernel return make_naive_tensor_view( c_start, make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, + make_tuple(1, kargs.stride_C), + number<1>{}, number<1>{}); } }(); @@ -174,14 +174,15 @@ struct GemmKernel auto c_pad_view = pad_tensor_view( c_tensor_view, make_tuple(number{}, number{}), - sequence < 0, - GemmPipeline::kPadC ? 1 : 0 > {}); - auto CBlockWindow_pad = make_tile_window( + // clang-format off + sequence{}); + // clang-format on + auto c_block_window = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - EpiloguePipeline{}(CBlockWindow_pad, acc); + EpiloguePipeline{}(c_block_window, c_block_tile); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index a49ffc2911..6387233c0f 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -9,26 +9,30 @@ namespace ck_tile { template struct GemmTilePartitioner { - using BlockGemmShape = ck_tile::remove_cvref_t; + using BlockGemmShape = remove_cvref_t; - static constexpr ck_tile::index_t kM = BlockGemmShape::kM; - static constexpr ck_tile::index_t kN = BlockGemmShape::kN; - static constexpr ck_tile::index_t kK = BlockGemmShape::kK; + static constexpr index_t kM = BlockGemmShape::kM; + static constexpr index_t kN = BlockGemmShape::kN; + static constexpr index_t kK = BlockGemmShape::kK; - CK_TILE_HOST static constexpr auto - GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size) + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size) { - ck_tile::index_t GridDimX = (M + kM - 1) / kM; - ck_tile::index_t GridDimY = (N + kN - 1) / kN; - ck_tile::index_t GridDimZ = batch_size; + index_t GridDimX = (M + kM - 1) / kM; + index_t GridDimY = (N + kN - 1) / kN; + index_t GridDimZ = batch_size; return dim3(GridDimX, GridDimY, GridDimZ); } + CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K) + { + return integer_divide_ceil(K, kK); + } + CK_TILE_DEVICE auto operator()() { const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN); - return ck_tile::make_tuple(iM, iN); + return make_tuple(iM, iN); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp new file mode 100644 index 0000000000..b9b45d3f42 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrMem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + // TODO: Is this 32K value gfx9 arch specific? + static constexpr index_t MinMemInFlyBytes = 32768; + + static constexpr index_t WgpPerCU = + (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil( + MinMemInFlyBytes / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t LocalPrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } +}; + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 +template +struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem +{ + using Base = BaseGemmPipelineAgBgCrMem; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t VectorSizeA = Problem::VectorSizeA; + static constexpr index_t VectorSizeB = Problem::VectorSizeB; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + + static constexpr bool kPadA = Problem::kPadA; + static constexpr bool kPadB = Problem::kPadB; + static constexpr bool kPadC = Problem::kPadC; + + // Where is the right place for HasHotLoop and TailNum ??? + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + struct PipelineImpl + { + }; + + template <> + struct PipelineImpl + { + template + CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, + SrcTileWindow& dram_tile_window) const + { + load_tile(dst_block_tile, dram_tile_window); + move_tile_window(dram_tile_window, {0, KPerBlock}); + } + + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile, + const ElementFunction& element_func) const + { + const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); + store_tile(lds_tile_window, block_tile_tmp); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + NPerBlock == + BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" + " or KPerBlock!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // TODO: LDS alignment should come from Policy! + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + constexpr auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + tuple_array a_block_tiles; + tuple_array b_block_tiles; + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + + // Global prefetch [1, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + }); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + // block_gemm.LocalPrefetch(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + + GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + auto HotLoopTail = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + + // block_gemm.LocalPrefetch(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); + }); + + block_sync_lds(); + // block_gemm.LocalPrefetch(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + // block_gemm.LocalPrefetch(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); + } + + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp new file mode 100644 index 0000000000..5e93ca21c0 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +enum struct GemmPipelineScheduler +{ + Intrawave, + Interwave, +}; + +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, + Even, + + // Long prefetch pipeline, up to 8 + One, + Two, + Three, + Four, + Five, + Six, + Seven, + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, +}; + +} // namespace ck_tile + +inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineScheduler& s) +{ + switch(s) + { + case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break; + case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break; + default: os << ""; + } + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const ck_tile::TailNumber& s) +{ + switch(s) + { + case ck_tile::TailNumber::Odd: os << "Odd"; break; + case ck_tile::TailNumber::Even: os << "Even"; break; + case ck_tile::TailNumber::One: os << "One"; break; + case ck_tile::TailNumber::Two: os << "Two"; break; + case ck_tile::TailNumber::Three: os << "Three"; break; + case ck_tile::TailNumber::Four: os << "Four"; break; + case ck_tile::TailNumber::Five: os << "Five"; break; + case ck_tile::TailNumber::Six: os << "Six"; break; + case ck_tile::TailNumber::Seven: os << "Seven"; break; + case ck_tile::TailNumber::Empty: os << "Empty"; break; + case ck_tile::TailNumber::Full: os << "Full"; break; + default: os << ""; + } + return os; +} diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 5ed7d036ea..a2424290e6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,27 +19,27 @@ struct GemmPipelineAGmemBGmemCRegV1 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; - static constexpr index_t AlignmentA = Problem::AlignmentA; - static constexpr index_t AlignmentB = Problem::AlignmentB; - static constexpr index_t AlignmentC = Problem::AlignmentC; + static constexpr index_t VectorSizeA = Problem::VectorSizeA; + static constexpr index_t VectorSizeB = Problem::VectorSizeB; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr bool kPadA = Problem::kPadA; static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadC = Problem::kPadC; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; - - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { - return ck_tile::integer_divide_ceil( + return integer_divide_ceil( sizeof(ADataType) * Policy::template MakeALdsBlockDescriptor().get_element_space_size(), 16) * @@ -48,7 +48,7 @@ struct GemmPipelineAGmemBGmemCRegV1 Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); } - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 8639f00fbb..199ba56aac 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -71,8 +71,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using namespace ck_tile; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; @@ -93,7 +91,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * MakeALdsBlockDescriptor().get_element_space_size(); @@ -101,7 +99,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * MakeBLdsBlockDescriptor().get_element_space_size(); @@ -109,7 +107,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index bff7fc0a0e..96a5a61c8b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,9 +25,9 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { - return ck_tile::integer_divide_ceil( + return integer_divide_ceil( sizeof(ADataType) * Policy::template MakeALdsBlockDescriptor().get_element_space_size(), 16) * diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index d7b3b24a4a..1156f549b6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -1,14 +1,15 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" - -#define VectorLoadSize 16 +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { +static constexpr int _VectorSize = 16; + template ; using GemmTraits = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr bool kPadA = GemmTraits::kPadA; static constexpr bool kPadB = GemmTraits::kPadB; static constexpr bool kPadC = GemmTraits::kPadC; - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; + static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType); + static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType); + static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType); +}; - static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); - static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType); - static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType); +template +struct UniversalGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using GemmTraits = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + + static constexpr bool kPadA = GemmTraits::kPadA; + static constexpr bool kPadB = GemmTraits::kPadB; + static constexpr bool kPadC = GemmTraits::kPadC; + + static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1; + static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1; + static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 98da1510c7..9d050be2fb 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -1,27 +1,25 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck_tile/core.hpp" - namespace ck_tile { template + typename ALayout_, + typename BLayout_, + typename CLayout_> struct TileGemmTraits { static constexpr bool kPadA = kPadA_; static constexpr bool kPadB = kPadB_; static constexpr bool kPadC = kPadC_; - using LayoutA = LayoutA_; - using LayoutB = LayoutB_; - using LayoutC = LayoutC_; + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index dd164e72ea..bb59a72982 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 #if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); #else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; #endif } @@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = a_vec; + ignore = b_vec; return CVecType{0.f}; #endif } @@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 #if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; #endif } @@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = a_vec; + ignore = b_vec; return CVecType{0.f}; #endif } @@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 0); }); #else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; #endif } @@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 }); return c_vec; #else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = a_vec; + ignore = b_vec; return CVecType{0.f}; #endif } @@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 0); }); #else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; #endif } @@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 }); return c_vec; #else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = a_vec; + ignore = b_vec; return CVecType{0.f}; #endif } @@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); }); #else - ck_tile::ignore = c_vec; - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = c_vec; + ignore = a_vec; + ignore = b_vec; #endif } @@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base }); return c_vec; #else - ck_tile::ignore = a_vec; - ck_tile::ignore = b_vec; + ignore = a_vec; + ignore = b_vec; return CVecType{0.f}; #endif } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 99cd5d787e..4183d9cb95 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher; // clang-format off // fp16 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; // bf16 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; // fp8 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // clang-format on } // namespace impl diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 9075ca2ed0..ac9c4311df 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(image_to_column) +add_subdirectory(gemm) diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt new file mode 100644 index 0000000000..f96ad9c6e0 --- /dev/null +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_gemm_mem_pipeline test_gemm_mem_pipeline.cpp) +endif() diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp b/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp new file mode 100644 index 0000000000..f72a80b5a3 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_mem_pipeline_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType + std::tuple< Row, Col, Row, F16, F16, F32, F16>, + std::tuple< Col, Row, Row, F16, F16, F32, F16>, + std::tuple< Row, Row, Row, F16, F16, F32, F16>, + std::tuple< Col, Col, Row, F16, F16, F32, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); + +#include "test_gemm_mem_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc new file mode 100644 index 0000000000..b26114f39d --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc @@ -0,0 +1,41 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 1024; + constexpr int K = 432; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipeline, Regular) +{ + std::vector Ms{512}; + constexpr int N = 1024; + constexpr int K = 512; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp new file mode 100644 index 0000000000..1b243ab437 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +class TestCkTileGemmMemPipeline : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + // TODO: expose tile size through test t-param ? + + struct gemm_basic_args + { + const void* p_a; + const void* p_b; + void* p_c; + ck_tile::index_t kbatch; + ck_tile::index_t M; + ck_tile::index_t N; + ck_tile::index_t K; + ck_tile::index_t stride_A; + ck_tile::index_t stride_B; + ck_tile::index_t stride_C; + }; + + void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s) + { + // TODO: This should be parameterized in tests + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool kPadA = true; + constexpr bool kPadB = true; + constexpr bool kPadC = true; + + constexpr int kBlockPerCu = 1; + + // =============================================== + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTilePartitioner; + + using GemmEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + + using Traits = ck_tile::TileGemmTraits; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(s.log_level_ > 0) + { + std::cout << "Lunching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + } + + public: + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA = 0, + const int StrideB = 0, + const int StrideC = 0) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); + std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); + std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + gemm_basic_args args; + args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); + args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); + args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); + args.kbatch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + invoke_gemm(args, ck_tile::stream_config{nullptr, false}); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + EXPECT_TRUE(pass); + } +};