From d2873859333f0a7b71f8b784cb0d3166e3473efd Mon Sep 17 00:00:00 2001
From: Cong Ma <142121551+CongMa13@users.noreply.github.com>
Date: Fri, 12 Dec 2025 17:08:26 -0700
Subject: [PATCH] [CK TILE GEMM STREAMK] update identifier names according to
 the new code style (#3348)

* [CK TILE GEMM STREAMK] update identifier names according to the new code
  style

[ROCm/composable_kernel commit: 9707ddb444f42b490c73b7884babccde2988ed7e]
---
 .../ck_tile/40_streamk_gemm/gemm_utils.hpp    |  56 +--
 .../40_streamk_gemm/run_gemm_example.inc      | 380 +++++++++---------
 .../40_streamk_gemm/streamk_gemm_basic.cpp    | 204 +++++-----
 3 files changed, 328 insertions(+), 312 deletions(-)

diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
index dad31ec637..34c6c6b0ae 100644
--- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
+++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
@@ -7,46 +7,46 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"
 
-struct GemmConfigBase
+struct GemmConfigurationBase
 {
-    static constexpr bool kPadM = true;
-    static constexpr bool kPadN = true;
-    static constexpr bool kPadK = true;
+    static constexpr bool PAD_M = true;
+    static constexpr bool PAD_N = true;
+    static constexpr bool PAD_K = true;
 
-    static constexpr bool PermuteA = false;
-    static constexpr bool PermuteB = false;
+    static constexpr bool PERMUTE_A = false;
+    static constexpr bool PERMUTE_B = false;
 
-    static constexpr bool TransposeC = false;
-    static constexpr bool UseStructuredSparsity = false;
+    static constexpr bool TRANSPOSE_C = false;
+    static constexpr bool USE_STRUCTURED_SPARSITY = false;
 
-    static constexpr int kBlockPerCu = 1;
-    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
-    static constexpr ck_tile::index_t NumWaveGroups = 1;
-    static constexpr bool Preshuffle = false;
-    static constexpr bool DoubleSmemBuffer = false;
+    static constexpr int BLOCK_PER_CU = 1;
+    static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
+    static constexpr bool PRESHUFFLE = false;
+    static constexpr bool DOUBLE_SMEM_BUFFER = false;
 };
 
-template <typename PrecType, bool Persistent_>
-struct GemmConfigMemoryInterwave : public GemmConfigBase
+template <typename PrecisionType, bool IsPersistent>
+struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
 {
-    static constexpr ck_tile::index_t M_Tile = 256;
-    static constexpr ck_tile::index_t N_Tile = 256;
-    static constexpr ck_tile::index_t K_Tile = 16;
+    static constexpr ck_tile::index_t M_TILE = 256;
+    static constexpr ck_tile::index_t N_TILE = 256;
+    static constexpr ck_tile::index_t K_TILE = 16;
 
-    static constexpr ck_tile::index_t M_Warp = 2;
-    static constexpr ck_tile::index_t N_Warp = 2;
-    static constexpr ck_tile::index_t K_Warp = 1;
+    static constexpr ck_tile::index_t M_WARP = 2;
+    static constexpr ck_tile::index_t N_WARP = 2;
+    static constexpr ck_tile::index_t K_WARP = 1;
 
-    static constexpr ck_tile::index_t M_Warp_Tile = 32;
-    static constexpr ck_tile::index_t N_Warp_Tile = 32;
-    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
+    static constexpr ck_tile::index_t M_WARP_TILE = 32;
+    static constexpr ck_tile::index_t N_WARP_TILE = 32;
+    static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;
 
-    static constexpr bool Persistent = Persistent_;
-    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr bool PERSISTENT = IsPersistent;
+    static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
 };
 
 template <typename ADataType_, typename BDataType_, typename AccDataType_, typename CDataType_>
-struct StreamKGemmTypeConfig
+struct StreamKGemmTypeConfiguration
 {
     using ADataType = ADataType_;
     using BDataType = BDataType_;
@@ -54,7 +54,7 @@ struct StreamKGemmTypeConfig
     using CDataType = CDataType_;
 };
 
-auto create_args(int argc, char* argv[])
+auto createArgs(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m", "512", "m dimension")
diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
index d18ac2e68a..7442bd33f2 100644
--- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
+++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
@@ -12,31 +12,35 @@ static constexpr inline auto is_row_major(Layout)
 }
 
 template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
-auto calculate_rtol_atol(const ck_tile::index_t K,
-                         const ck_tile::index_t kbatch,
-                         const float max_accumulated_value)
+auto calculateRtolAtol(const ck_tile::index_t k_dim,
+                       const ck_tile::index_t k_batch,
+                       const float max_accumulated_value)
 {
     using ComputeType =
         std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
     // Calculate thresholds
-    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
-        ck_tile::integer_divide_ceil(K, kbatch));
-    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
-        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
+    const auto relative_tolerance =
+        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
+            ck_tile::integer_divide_ceil(k_dim, k_batch));
+    const auto absolute_tolerance =
+        ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
+            max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
     // Calculate error due to multiple WGs working in the same C macro tile
-    const auto rtol_split_k =
-        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
-    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
-        max_accumulated_value, kbatch);
+    const auto relative_tolerance_split_k =
+        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
+    const auto absolute_tolerance_split_k =
+        ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
+                                                                         k_batch);
     // Use higher threshold
-    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
+    return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
+                               std::max(absolute_tolerance, absolute_tolerance_split_k));
 }
 
-template <typename GemmConfig,
+template <typename GemmConfiguration,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
           typename AccDataType,
           typename CDataType,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename CLayout,
           ck_tile::StreamKReductionStrategy ReductionStrategy>
 std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
-                                         const ck_tile::stream_config& s);
+                                         const ck_tile::stream_config& stream_config);
 
-template <typename GemmConfig,
+template <typename GemmConfiguration,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
           typename AccDataType,
           typename CDataType,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename CLayout>
-std::tuple<float, ck_tile::index_t>
-invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
-            ck_tile::DeviceMem& b_k_n_dev_buf,
-            ck_tile::DeviceMem& c_m_n_dev_buf,
-            ck_tile::index_t M,
-            ck_tile::index_t N,
-            ck_tile::index_t K,
-            ck_tile::index_t stride_A,
-            ck_tile::index_t stride_B,
-            ck_tile::index_t stride_C,
-            int n_warmup,
-            int n_repeat,
-            bool flush_cache,
-            ck_tile::StreamKReductionStrategy reduction_strategy)
+std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
+                                               ck_tile::DeviceMem& b_k_n_device_memory,
+                                               ck_tile::DeviceMem& c_m_n_device_memory,
+                                               ck_tile::index_t m_dim,
+                                               ck_tile::index_t n_dim,
+                                               ck_tile::index_t k_dim,
+                                               ck_tile::index_t stride_a,
+                                               ck_tile::index_t stride_b,
+                                               ck_tile::index_t stride_c,
+                                               int warmup_iterations,
+                                               int repeat_iterations,
+                                               bool flush_cache,
+                                               ck_tile::StreamKReductionStrategy reduction_strategy)
 {
-    ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
-                                  b_k_n_dev_buf.GetDeviceBuffer(),
-                                  c_m_n_dev_buf.GetDeviceBuffer(),
-                                  M,
-                                  N,
-                                  K,
-                                  stride_A,
-                                  stride_B,
-                                  stride_C};
+    ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
+                                  b_k_n_device_memory.GetDeviceBuffer(),
+                                  c_m_n_device_memory.GetDeviceBuffer(),
+                                  m_dim,
+                                  n_dim,
+                                  k_dim,
+                                  stride_a,
+                                  stride_b,
+                                  stride_c};
 
-    std::tuple<float, ck_tile::index_t> ave_time_and_batch;
+    std::tuple<float, ck_tile::index_t> average_time_and_batch;
 
     if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
     {
-        ave_time_and_batch = gemm<GemmConfig,
-                                  ADataType,
-                                  BDataType,
-                                  DsDataType,
-                                  AccDataType,
-                                  CDataType,
-                                  ALayout,
-                                  BLayout,
-                                  DsLayout,
-                                  CLayout,
-                                  ck_tile::StreamKReductionStrategy::Atomic>(
-            args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
+        average_time_and_batch = gemm<GemmConfiguration,
+                                      ADataType,
+                                      BDataType,
+                                      DsDataType,
+                                      AccDataType,
+                                      CDataType,
+                                      ALayout,
+                                      BLayout,
+                                      DsLayout,
+                                      CLayout,
+                                      ck_tile::StreamKReductionStrategy::Atomic>(
+            args,
+            ck_tile::stream_config{
+                nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
    }
    else /*Reduction*/
    {
-        ave_time_and_batch = gemm<GemmConfig,
-                                  ADataType,
-                                  BDataType,
-                                  DsDataType,
-                                  AccDataType,
-                                  CDataType,
-                                  ALayout,
-                                  BLayout,
-                                  DsLayout,
-                                  CLayout,
-                                  ck_tile::StreamKReductionStrategy::Reduction>(
-            args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
+        average_time_and_batch = gemm<GemmConfiguration,
+                                      ADataType,
+                                      BDataType,
+                                      DsDataType,
+                                      AccDataType,
+                                      CDataType,
+                                      ALayout,
+                                      BLayout,
+                                      DsLayout,
+                                      CLayout,
+                                      ck_tile::StreamKReductionStrategy::Reduction>(
+            args,
+            ck_tile::stream_config{
+                nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
    }
 
-    return ave_time_and_batch;
+    return average_time_and_batch;
 }
 
 template <typename CDataType>
-bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
-               const ck_tile::HostTensor<CDataType>& c_m_n_ref,
-               const ck_tile::tuple<float, float>& rtol_atol,
-               const char* variant)
+bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
+              const ck_tile::HostTensor<CDataType>& c_m_n_reference,
+              const ck_tile::tuple<float, float>& relative_absolute_tolerances,
+              const char* variant)
 {
-    bool pass = ck_tile::check_err(c_m_n_dev_result,
-                                   c_m_n_ref,
+    bool pass = ck_tile::check_err(c_m_n_device_result,
+                                   c_m_n_reference,
                                    "Error: Incorrect results!",
-                                   rtol_atol.at(ck_tile::number<0>{}),
-                                   rtol_atol.at(ck_tile::number<1>{}));
+                                   relative_absolute_tolerances.at(ck_tile::number<0>{}),
+                                   relative_absolute_tolerances.at(ck_tile::number<1>{}));
 
-    std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
-              << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
+    std::cout << "Relative error threshold: "
+              << relative_absolute_tolerances.at(ck_tile::number<0>{})
+              << " Absolute error threshold: "
+              << relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
 
     std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
"correct" : "fail") << std::endl; return pass; } -ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy) { if(strategy == "atomic") { @@ -156,172 +165,169 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string } } -template -int run_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) +int runGemmExampleWithLayouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) { - auto [result, arg_parser] = create_args(argc, argv); + auto [result, arg_parser] = createArgs(argc, argv); if(!result) return -1; - static_assert(!GemmConfig::Preshuffle, "Not implemented"); - static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); - static_assert(!GemmConfig::PermuteA, "Not implemented"); - static_assert(!GemmConfig::PermuteB, "Not implemented"); + static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented"); + static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented"); + static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented"); - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; + using ADataType = typename TypeConfiguration::ADataType; + using BDataType = typename TypeConfiguration::BDataType; + using AccumulatorDataType = typename TypeConfiguration::AccDataType; + using CDataType = typename TypeConfiguration::CDataType; - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t m_dim = arg_parser.get_int("m"); + ck_tile::index_t n_dim = arg_parser.get_int("n"); + ck_tile::index_t k_dim = arg_parser.get_int("k"); - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); + int warmup_iterations = arg_parser.get_int("warmup"); + int repeat_iterations = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); bool flush_cache = arg_parser.get_bool("flush_cache"); - ck_tile::StreamKReductionStrategy reduction_strategy = - get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); + getReductionStrategyValue(arg_parser.get_str("reduction_strategy")); - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout)); + stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout)); + stride_c = 
 
-    ck_tile::HostTensor<ADataType> a_m_k(
-        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
-    ck_tile::HostTensor<BDataType> b_k_n(
-        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
-    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
-        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+    ck_tile::HostTensor<ADataType> a_m_k_host(
+        ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_k_n_host(
+        ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_m_n_device_result(
+        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
 
     if(init_method == 0)
     {
-        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
-        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
+        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
     }
     else if(init_method == 1)
     {
-        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
-        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
+        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
    }
    else if(init_method == 2)
    {
-        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
-        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
+        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
+        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
    }
    else
    {
-        a_m_k.SetZero();
-        b_k_n.SetZero();
+        a_m_k_host.SetZero();
+        b_k_n_host.SetZero();
    }
 
-    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
-    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());
 
-    a_m_k_dev_buf.ToDevice(a_m_k.data());
-    b_k_n_dev_buf.ToDevice(b_k_n.data());
-    c_m_n_dev_buf.SetZero();
-    c_m_n_dev_result.SetZero();
+    a_m_k_device_memory.ToDevice(a_m_k_host.data());
+    b_k_n_device_memory.ToDevice(b_k_n_host.data());
+    c_m_n_device_memory.SetZero();
+    c_m_n_device_result.SetZero();
 
-    auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
-                                                    ADataType,
-                                                    BDataType,
-                                                    ck_tile::tuple<>,
-                                                    AccDataType,
-                                                    CDataType,
-                                                    ALayout,
-                                                    BLayout,
-                                                    ck_tile::tuple<>,
-                                                    CLayout>(a_m_k_dev_buf,
-                                                             b_k_n_dev_buf,
-                                                             c_m_n_dev_buf,
-                                                             M,
-                                                             N,
-                                                             K,
-                                                             stride_A,
-                                                             stride_B,
-                                                             stride_C,
-                                                             n_warmup,
-                                                             n_repeat,
-                                                             flush_cache,
-                                                             reduction_strategy);
+    auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
+                                                       ADataType,
+                                                       BDataType,
+                                                       ck_tile::tuple<>,
+                                                       AccumulatorDataType,
+                                                       CDataType,
+                                                       ALayout,
+                                                       BLayout,
+                                                       ck_tile::tuple<>,
+                                                       CLayout>(a_m_k_device_memory,
+                                                                b_k_n_device_memory,
+                                                                c_m_n_device_memory,
+                                                                m_dim,
+                                                                n_dim,
+                                                                k_dim,
+                                                                stride_a,
+                                                                stride_b,
+                                                                stride_c,
+                                                                warmup_iterations,
+                                                                repeat_iterations,
+                                                                flush_cache,
+                                                                reduction_strategy);
 
-    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    c_m_n_device_memory.FromDevice(c_m_n_device_result.data());
 
-    std::size_t flop = std::size_t(2) * M * N * K;
-    std::size_t num_byte =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
-    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-    float gb_per_sec = num_byte / 1.E6 / ave_time;
-
-    std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
-              << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
+    std::size_t flop = std::size_t(2) * m_dim * n_dim * k_dim;
+    std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
+                           sizeof(CDataType) * m_dim * n_dim;
+    float tflops = static_cast<float>(flop) / 1.E9 / average_time;
+    float gb_per_sec = num_byte / 1.E6 / average_time;
+    std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
+              << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
               << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
               << " C_Layout=" << CLayout::name
               << " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
               << " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
               << " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
               << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
-              << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
+              << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
               << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
-
     bool pass = false;
     // Memory on host to store gpu reference result
-    ck_tile::HostTensor<CDataType> c_m_n_ref(
-        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
-    c_m_n_ref.SetZero();
+    ck_tile::HostTensor<CDataType> c_m_n_reference(
+        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
+    c_m_n_reference.SetZero();
 
     if(arg_parser.get_int("v") == 1) // Validate on the CPU
     {
-        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
-            a_m_k, b_k_n, c_m_n_ref);
+        ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
+            a_m_k_host, b_k_n_host, c_m_n_reference);
         const float max_accumulated_value =
-            *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
-        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
-            K, num_wgs_per_tile, max_accumulated_value);
-        pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
+            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
+        const auto relative_absolute_tolerances =
+            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
+                k_dim, num_wgs_per_tile, max_accumulated_value);
+        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
    }
    else if(arg_parser.get_int("v") == 2) // Validate on the GPU
    {
         // Memory on device to store gpu reference result
-        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
-        c_m_n_gpu_buf_ref.SetZero();
-
-        ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
-        BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
-        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
+        ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
+            c_m_n_reference.get_element_space_size_in_bytes());
+        c_m_n_gpu_buffer_reference.SetZero();
+        ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
+        BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
+        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
 
-        ck_tile::reference_gemm_gpu<ADataType,
-                                    BDataType,
-                                    AccDataType,
-                                    CDataType,
-                                    ALayout,
-                                    BLayout,
-                                    CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
-
-        c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
+        ck_tile::reference_gemm_gpu<ADataType,
+                                    BDataType,
+                                    AccumulatorDataType,
+                                    CDataType,
+                                    ALayout,
+                                    BLayout,
+                                    CLayout>(
+            d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
+        c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());
         const float max_accumulated_value =
-            *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
-        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
-            K, num_wgs_per_tile, max_accumulated_value);
-        pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
+            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
+        const auto relative_absolute_tolerances =
+            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
+                k_dim, num_wgs_per_tile, max_accumulated_value);
+        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
    }
 
    return pass;
diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
index 83795fbf6a..d3ee9fe9c6 100644
--- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
+++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
@@ -4,11 +4,11 @@
 #include "gemm_utils.hpp"
 #include "ck_tile/ops/common.hpp"
 
-template <typename GemmConfig,
+template <typename GemmConfiguration,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
           typename AccDataType,
           typename CDataType,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename CLayout,
           ck_tile::StreamKReductionStrategy ReductionStrategy>
 std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
-                                         const ck_tile::stream_config& s)
+                                         const ck_tile::stream_config& stream_config)
 {
-    using GemmShape = ck_tile::TileGemmShape<
-        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
-        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
-        ck_tile::
-            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
-        GemmConfig::PermuteA,
-        GemmConfig::PermuteB>;
+    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
+                                                               GemmConfiguration::N_TILE,
+                                                               GemmConfiguration::K_TILE>,
+                                             ck_tile::sequence<GemmConfiguration::M_WARP,
+                                                               GemmConfiguration::N_WARP,
+                                                               GemmConfiguration::K_WARP>,
+                                             ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
+                                                               GemmConfiguration::N_WARP_TILE,
+                                                               GemmConfiguration::K_WARP_TILE>,
+                                             GemmConfiguration::PERMUTE_A,
+                                             GemmConfiguration::PERMUTE_B>;
 
-    using TilePartitioner =
-        ck_tile::StreamKTilePartitioner<GemmShape, GemmConfig::Persistent>;
+    using TilePartitioner = ck_tile::
+        StreamKTilePartitioner<GemmShape, GemmConfiguration::PERSISTENT>;
 
-    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
-                                                                 GemmConfig::kPadN,
-                                                                 GemmConfig::kPadK,
-                                                                 GemmConfig::DoubleSmemBuffer,
-                                                                 ALayout,
-                                                                 BLayout,
-                                                                 CLayout,
-                                                                 GemmConfig::TransposeC,
-                                                                 GemmConfig::UseStructuredSparsity,
-                                                                 GemmConfig::Persistent,
-                                                                 GemmConfig::NumWaveGroups>;
+    using GemmUniversalTraits =
+        ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
+                                         GemmConfiguration::PAD_N,
+                                         GemmConfiguration::PAD_K,
+                                         GemmConfiguration::DOUBLE_SMEM_BUFFER,
+                                         ALayout,
+                                         BLayout,
+                                         CLayout,
+                                         GemmConfiguration::TRANSPOSE_C,
+                                         GemmConfiguration::USE_STRUCTURED_SPARSITY,
+                                         GemmConfiguration::PERSISTENT,
+                                         GemmConfiguration::NUM_WAVE_GROUPS>;
 
-    const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
+    const auto runKernel = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
         // We create the GEMM pipeline without specifying has_hot_loop or tail_num.
         // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
         // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
         // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
-        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
-                                                                           BDataType,
-                                                                           AccDataType,
-                                                                           GemmShape,
-                                                                           GemmUniversalTraits,
-                                                                           GemmConfig::Scheduler>;
+        using UniversalGemmProblem =
+            ck_tile::UniversalGemmPipelineProblem<ADataType,
+                                                  BDataType,
+                                                  AccDataType,
+                                                  GemmShape,
+                                                  GemmUniversalTraits,
+                                                  GemmConfiguration::SCHEDULER>;
 
         using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
@@ -61,39 +67,39 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                 ck_tile::CShuffleEpilogueProblem<
-                                                 GemmConfig::NumWaveGroups>>;
+                                                 GemmConfiguration::NUM_WAVE_GROUPS>>;
 
         using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
 
-        auto kargs = Kernel::MakeKernelArgs(args);
-        const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
+        auto kernel_args = Kernel::MakeKernelArgs(args);
+        const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
         ck_tile::DeviceMem workspace_data(workspace_size);
         workspace_data.SetZero();
-        kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
+        kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
 
-        dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
+        dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
         dim3 blocks = Kernel::BlockSize();
 
-        if(!Kernel::IsSupportedArgument(kargs))
+        if(!Kernel::IsSupportedArgument(kernel_args))
         {
             throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
         }
 
-        if(s.log_level_ > 0)
+        if(stream_config.log_level_ > 0)
         {
             std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                       << "shape: " << GemmShape::GetName() << '\n'
@@ -109,7 +115,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
         {
             // Clear the output C tensor results after each repetition of the kernel
             hipGetErrorString(hipMemsetAsync(
-                args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+                args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
         }
         else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
         {
@@ -120,45 +126,47 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
 
         std::function<void()> preprocess = reset_data_buffers;
 
-        float ave_time = ck_tile::launch_kernel_time_mask(
-            s,
-            preprocess,
-            ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+        float average_time =
+            ck_tile::launch_kernel_time_mask(stream_config,
+                                             preprocess,
+                                             ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
+                                                 Kernel{}, grids, blocks, 0, kernel_args));
 
-        ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
-        return std::tuple{ave_time, num_wgs_per_tile};
+        ck_tile::index_t num_wgs_per_tile =
+            kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
+        return std::tuple{average_time, num_wgs_per_tile};
    };
 
    if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
    {
-        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                              ck_tile::memory_operation_enum::atomic_add>{});
+        return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                                    ck_tile::memory_operation_enum::atomic_add>{});
    }
    else // We are using ck_tile::StreamKReductionStrategy::Reduction
    {
-        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                              ck_tile::memory_operation_enum::set>{});
+        return runKernel(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                                    ck_tile::memory_operation_enum::set>{});
    }
 }
 
 #include "run_gemm_example.inc"
 
-template <typename GemmConfig, typename TypeConfig>
-int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
+template <typename GemmConfiguration, typename TypeConfiguration>
+int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
 {
     using Row = ck_tile::tensor_layout::gemm::RowMajor;
     using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 
     if(a_layout == "R" && b_layout == "C")
     {
-        return run_gemm_example_with_layouts<GemmConfig, TypeConfig, Row, Col, Row>(
+        return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration, Row, Col, Row>(
             argc, argv, Row{}, Col{}, Row{});
    }
    else
@@ -169,72 +177,74 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
 
     return 0;
 }
 
-template
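
For readers unfamiliar with the tolerance logic touched in calculateRtolAtol above: each WG
accumulates only ceil(K / k_batch) products of a C tile, but combining k_batch partial results
into one tile adds its own rounding error, so the example keeps the looser of the two bounds.
The following standalone sketch illustrates that reasoning; relative_threshold here is a
simplified, hypothetical stand-in for ck_tile::get_relative_threshold (the real function is
templated on compute, output, and accumulator types), not part of the patch.

    #include <algorithm>
    #include <cstdio>

    // Simplified model: the relative error of summing n values grows roughly
    // linearly with n times the machine epsilon of the compute type.
    static float relative_threshold(float machine_epsilon, int n)
    {
        return machine_epsilon * static_cast<float>(n);
    }

    int main()
    {
        const float eps_fp16 = 1.0f / 1024.0f; // machine epsilon of fp16 (2^-10)
        const int k_dim = 4096, k_batch = 8;
        const int k_per_wg = (k_dim + k_batch - 1) / k_batch; // integer_divide_ceil(k_dim, k_batch)

        // Error from one WG accumulating k_per_wg products.
        const float rtol = relative_threshold(eps_fp16, k_per_wg);
        // Error from combining k_batch partial results in the same C macro tile.
        const float rtol_split_k = relative_threshold(eps_fp16, k_batch);
        // As in the patch, the looser bound wins.
        std::printf("rtol=%g rtol_split_k=%g combined=%g\n",
                    rtol, rtol_split_k, std::max(rtol, rtol_split_k));
        return 0;
    }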
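
A hypothetical usage sketch of the renamed configuration types follows. It is not part of the
patch: it assumes the template parameter lists reconstructed above (PrecisionType/IsPersistent
for GemmConfigurationMemoryInterwave, and the ADataType_/BDataType_/AccDataType_/CDataType_
order for StreamKGemmTypeConfiguration), gemm_utils.hpp on the include path, and
ck_tile::half_t from the CK Tile headers.

    #include "gemm_utils.hpp"

    // fp16, non-persistent variant, mirroring what streamk_gemm_basic.cpp selects.
    using GemmConfig = GemmConfigurationMemoryInterwave<ck_tile::half_t, /*IsPersistent=*/false>;
    using TypeConfig = StreamKGemmTypeConfiguration<ck_tile::half_t,  // ADataType
                                                    ck_tile::half_t,  // BDataType
                                                    float,            // AccDataType
                                                    ck_tile::half_t>; // CDataType

    // The renamed SCREAMING_SNAKE_CASE members keep their old values and meanings.
    static_assert(GemmConfig::M_TILE == 256 && GemmConfig::N_TILE == 256 && GemmConfig::K_TILE == 16);
    static_assert(GemmConfig::K_WARP_TILE == 8); // sizeof(ck_tile::half_t) == 2
    static_assert(!GemmConfig::PERSISTENT && GemmConfig::NUM_WAVE_GROUPS == 1);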