Add default arguments for prologue and epilogue. (#2020)

This commit is contained in:
Mirza Halilčević
2025-03-26 17:28:40 +01:00
committed by GitHub
parent 99b2bbc1d6
commit 21e0ca197d
4 changed files with 17 additions and 23 deletions

View File

@@ -39,8 +39,8 @@ struct Problem
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
const std::string& prologue = "",
const std::string& epilogue = "") const;
};
} // namespace device_batched_gemm_softmax_gemm

View File

@@ -37,8 +37,8 @@ struct Problem
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
const std::string& prologue = "",
const std::string& epilogue = "") const;
};
} // namespace device_gemm_multiple_d

View File

@@ -48,22 +48,19 @@ TEST_CASE(test_problem_kernel)
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 3));
std::string epilogue = "";
std::string prologue = "";
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
auto solutions = prob.GetSolutions("gfx90a");
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)},
{"o", std::to_string(prob.O)}});
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)},
{"o", std::to_string(prob.O)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;

View File

@@ -53,21 +53,18 @@ TEST_CASE(test_problem_kernel)
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2));
std::string epilogue = "";
std::string prologue = "";
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
auto solutions = prob.GetSolutions("gfx90a");
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;