From bb1298c699e4a1207df893a987019b80044a2faf Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Oct 2024 11:48:43 +0000 Subject: [PATCH] Refactor gemm examples. --- example/ck_tile/03_gemm/gemm_basic.cpp | 280 ++---------------- example/ck_tile/03_gemm/gemm_basic.hpp | 20 ++ .../03_gemm/gemm_basic_mem_pipeline.cpp | 194 +----------- example/ck_tile/03_gemm/run_gemm_example.inc | 177 +++++++++++ 4 files changed, 230 insertions(+), 441 deletions(-) create mode 100644 example/ck_tile/03_gemm/run_gemm_example.inc diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index a81ea3fdd2..c65e4174ee 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -15,223 +15,16 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "1024", "m dimension") - .insert("n", "2048", "n dimension") - .insert("k", "64", "k dimension") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "10", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template +template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadA = true; constexpr bool kPadB = true; + constexpr bool kPadC = true; constexpr int kBlockPerCu = 1; - using TilePartitioner = ck_tile::GemmTilePartitioner; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); - constexpr dim3 blocks = Kernel::BlockSize(); - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; -} - -template -float invoke_gemm(ck_tile::DeviceMem& a_buf, - ck_tile::DeviceMem& b_buf, - ck_tile::DeviceMem& c_buf, - const ck_tile::ArgParser& arg_parser) -{ - - std::string data_type = arg_parser.get_str("prec"); - - if(data_type != DataTypeTraits::name) - { - std::cerr << "Data type mismatch: expected " << DataTypeTraits::name << ", got " - << data_type << std::endl; - return -1; // Or handle the error appropriately - } - - ck_tile::index_t batch_size = arg_parser.get_int("b"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); - - gemm_basic_args args; - args.p_a = a_buf.GetDeviceBuffer(); - args.p_b = b_buf.GetDeviceBuffer(); - args.p_c = c_buf.GetDeviceBuffer(); - args.kbatch = batch_size; - args.M = M; - args.N = N; - args.K = K; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - args.stride_A = f_get_default_stride(M, K, stride_a, LayoutA{}); - args.stride_B = f_get_default_stride(K, N, stride_b, LayoutB{}); - args.stride_C = f_get_default_stride(M, N, stride_c, LayoutC{}); - - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true}); - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "The overall perfomance of the GEMM with " - << "[" << data_type << "]" - << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K - << " is: \n"; - std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n" - << std::flush; - - return ave_time; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - using ALayout = ck_tile::tensor_layout::gemm::RowMajor; - using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; - using CLayout = ck_tile::tensor_layout::gemm::RowMajor; - - using namespace ck_tile::literals; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); - stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); - stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); - - ck_tile::HostTensor a_host(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); - ck_tile::HostTensor b_host(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); - - ck_tile::HostTensor c_host_ref(f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - ck_tile::HostTensor c_host_dev(f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); - - ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); - - a_buf.ToDevice(a_host.data()); - b_buf.ToDevice(b_host.data()); - - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; - // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; @@ -263,51 +56,40 @@ int main(int argc, char* argv[]) using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; - invoke_gemm(a_buf, b_buf, c_buf, arg_parser); + using TilePartitioner = ck_tile::GemmTilePartitioner; + using GemmEpilogue = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; - c_buf.FromDevice(c_host_dev.data()); + auto kargs = Kernel::MakeKargs(args.p_a, + args.p_b, + args.p_c, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_C); - bool pass_cpu = true; + const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + constexpr dim3 blocks = Kernel::BlockSize(); - if(arg_parser.get_int("v") == 1) + if(s.log_level_ > 0) { - // ToDo: Will Add the Element Op (bias) verification in the future. - ck_tile::reference_gemm( - a_host, b_host, c_host_ref); - - pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref); - - std::cout << "The CPU verification result is:" << (pass_cpu ? "correct" : "fail") - << std::flush; + std::cout << "Lunching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } - bool pass_gpu = true; + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - if(arg_parser.get_int("v") == 2) - { - ck_tile::HostTensor c_host_gpu_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); - c_gpu_buf.SetZero(); - - ck_tile::reference_gemm_gpu( - a_buf, b_buf, c_gpu_buf, M, N, K, stride_A, stride_B, stride_C); - - c_gpu_buf.FromDevice(c_host_gpu_ref.data()); - - pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref); - - std::cout << "The GPU verification result is: " << (pass_gpu ? "correct" : "fail") - << std::flush; - } - - std::cout << std::endl << std::flush; - - return !pass_gpu; + return ave_time; } + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index e3b5f86f7b..b3b79dc1cd 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -65,5 +65,25 @@ struct gemm_basic_args ck_tile::index_t stride_C; }; +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("b", "1", "batch size") + .insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + // host API float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp index 85a850ffad..9ca83fed10 100644 --- a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp @@ -1,4 +1,3 @@ - // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. @@ -15,26 +14,6 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "4096", "k dimension") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { @@ -206,175 +185,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) return ave_time; } -template -float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - ck_tile::index_t kbatch, - int n_warmup, - int n_repeat) -{ - gemm_basic_args args; - args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); - args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); - args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); - args.kbatch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; +#include "run_gemm_example.inc" - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - - std::string op_name{"Gemm{MemBoundPipeline}"}; - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; - - return ave_time; -} - -int main(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - ck_tile::index_t batch_size = arg_parser.get_int("b"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - - using ALayout = ck_tile::tensor_layout::gemm::RowMajor; - using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; - using CLayout = ck_tile::tensor_layout::gemm::RowMajor; - - using namespace ck_tile::literals; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); - stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); - stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); - - ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); - ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); - ck_tile::HostTensor c_m_n_dev_result( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - - // TODO: add different init types - - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); - - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - - invoke_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - batch_size, - n_warmup, - n_repeat); - - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - bool pass = true; - - if(arg_parser.get_int("v") == 1) - { - ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - c_m_n_host_ref.SetZero(); - - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); - - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); - - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; - } - else if(arg_parser.get_int("v") == 2) - { - ck_tile::HostTensor c_m_n_gpu_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_ref.SetZero(); - c_m_n_gpu_buf_ref.SetZero(); - - ck_tile::reference_gemm_gpu( - a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); - - std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; - } - - return pass; -} +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc new file mode 100644 index 0000000000..290af0f962 --- /dev/null +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + gemm_basic_args args; + args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); + args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); + args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); + args.kbatch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm{MemBoundPipeline}"}; + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +int run_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t batch_size = arg_parser.get_int("b"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + using ALayout = ck_tile::tensor_layout::gemm::RowMajor; + using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + // TODO: add different init types + + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_size, + n_warmup, + n_repeat); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::HostTensor c_m_n_gpu_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ck_tile::reference_gemm_gpu( + a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +}