mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Introduce gemm_softmax_gemm to codegen.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
@@ -15,13 +17,59 @@
|
||||
using half = _Float16;
|
||||
// using half = __fp16;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
const char* const disable_warning_pragma = R"__migraphx__(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
${content}
|
||||
#pragma clang diagnostic pop
|
||||
)__migraphx__";
|
||||
|
||||
template <class P>
|
||||
std::string ck_disable_warnings(P p)
|
||||
{
|
||||
return ck::host::InterpolateString(disable_warning_pragma,
|
||||
{{"content", std::string{p.data(), p.size()}}});
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, std::string> create_ck_header_strings()
|
||||
{
|
||||
std::unordered_map<std::string, std::string> result;
|
||||
auto ck_headers = ck::host::GetHeaders();
|
||||
|
||||
std::transform(
|
||||
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
|
||||
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::vector<rtc::src_file> create_ck_headers()
|
||||
{
|
||||
static const auto& header_strings = create_ck_header_strings();
|
||||
std::vector<rtc::src_file> srcs;
|
||||
std::transform(
|
||||
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto& p) -> rtc::src_file {
|
||||
std::string sec(p.second.begin(), p.second.end());
|
||||
return {p.first, sec};
|
||||
});
|
||||
return srcs;
|
||||
}
|
||||
|
||||
static inline const std::vector<rtc::src_file>& ck_headers()
|
||||
{
|
||||
static const auto& headers = create_ck_headers();
|
||||
return headers;
|
||||
}
|
||||
|
||||
std::vector<rtc::src_file> get_headers_for_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
std::string sec(p.second.begin(), p.second.end());
|
||||
return {p.first, sec};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
@@ -130,10 +178,13 @@ const std::string gemm_compile_check = R"__ck__(
|
||||
|
||||
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
|
||||
using G = ${template};
|
||||
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
|
||||
ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
|
||||
constexpr auto desc =
|
||||
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
|
||||
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${n})));
|
||||
|
||||
static_assert(desc.IsValid(), "Invalid ck gemm.");
|
||||
|
||||
@@ -163,23 +214,32 @@ TEST_CASE(test_problem_kernel)
|
||||
std::string epilogue = "";
|
||||
std::string prologue = "";
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
|
||||
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
|
||||
std::cout << "Num solutions: " << solutions.size() << std::endl;
|
||||
for(auto i = 0; i < solutions.size(); ++i)
|
||||
{
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
// auto srcs = get_headers_for_test();
|
||||
// srcs.push_back({"main.cpp", src});
|
||||
// rtc::compile_options options;
|
||||
// options.kernel_name = "f";
|
||||
rtc::hip_compile_options options;
|
||||
options.kernel_name = "f";
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
options.additional_src_files = ck_headers();
|
||||
// auto k = rtc::compile_kernel(srcs, options);
|
||||
std::cout << src << std::endl;
|
||||
auto k = rtc::compile_hip_code_object(src, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
ck::host::integer_divide_ceil(prob.N, n_per_block);
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
|
||||
|
||||
@@ -187,4 +247,34 @@ TEST_CASE(test_problem_kernel)
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE(test_gemm_softmax_gemm)
|
||||
{
|
||||
ck::host::device_batched_gemm_softmax_gemm::Problem prob;
|
||||
prob.TransA = false;
|
||||
prob.TransB = true;
|
||||
prob.TransB1 = false;
|
||||
prob.TransC = false;
|
||||
prob.M = 1024;
|
||||
prob.N = 1024;
|
||||
prob.K = 1024;
|
||||
prob.O = 1024;
|
||||
check_all<half> check;
|
||||
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
|
||||
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
|
||||
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
|
||||
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 3));
|
||||
|
||||
std::string epilogue = "";
|
||||
std::string prologue = "";
|
||||
|
||||
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
|
||||
std::cout << "Num solutions: " << solutions.size() << std::endl;
|
||||
|
||||
for(auto i = 0; i < solutions.size(); ++i) {
|
||||
std::cout << "Solution " << i << std::endl;
|
||||
std::cout << solutions[i].ToTemplateString() << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
|
||||
Reference in New Issue
Block a user