From 97b32a1f18327226556b80863e066944b4e00767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mirza=20Halil=C4=8Devi=C4=87?= <109971222+mirza-halilcevic@users.noreply.github.com> Date: Wed, 26 Mar 2025 17:28:40 +0100 Subject: [PATCH] Add default arguments for prologue and epilogue. (#2020) [ROCm/composable_kernel commit: 21e0ca197de46062ee72f4ed773696a4f266aa9f] --- .../problem.hpp | 4 ++-- .../ck/host/device_gemm_multiple_d/problem.hpp | 4 ++-- codegen/test/batched_gemm_softmax_gemm.cpp | 17 +++++++---------- codegen/test/gemm_multiple_d.cpp | 15 ++++++--------- 4 files changed, 17 insertions(+), 23 deletions(-) diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp index 8e68f6cc88..30dd1487ca 100644 --- a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -39,8 +39,8 @@ struct Problem // returns a list of instances based on the problem spec and provided fusion operations std::vector GetSolutions(const std::string& arch, - const std::string& prologue, - const std::string& epilogue) const; + const std::string& prologue = "", + const std::string& epilogue = "") const; }; } // namespace device_batched_gemm_softmax_gemm diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp index f4036328ec..1c65fb71ff 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -37,8 +37,8 @@ struct Problem // returns a list of instances based on the problem spec and provided fusion operations std::vector GetSolutions(const std::string& arch, - const std::string& prologue, - const std::string& epilogue) const; + const std::string& prologue = "", + const std::string& epilogue = "") const; }; } // namespace device_gemm_multiple_d diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 0de8dbdd51..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -48,22 +48,19 @@ TEST_CASE(test_problem_kernel) auto b1 = to_gpu(generate_buffer(1024 * 1024, 2)); auto c = to_gpu(generate_buffer(1024 * 1024, 3)); - std::string epilogue = ""; - std::string prologue = ""; - - auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + auto solutions = prob.GetSolutions("gfx90a"); std::cout << "Num solutions: " << solutions.size() << std::endl; for(auto i = 0; i < solutions.size(); ++i) { std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 2a383fc1c8..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -53,21 +53,18 @@ TEST_CASE(test_problem_kernel) auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto c = to_gpu(generate_buffer(1024 * 1024, 2)); - std::string epilogue = ""; - std::string prologue = ""; - - auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + auto solutions = prob.GetSolutions("gfx90a"); std::cout << "Num solutions: " << solutions.size() << std::endl; for(auto i = 0; i < solutions.size(); ++i) { std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options;