mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Merge branch 'ck_mgx_temp' into ck_migraphx_integration
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include "common.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
@@ -21,10 +23,13 @@ const std::string gemm_compile_check = R"__ck__(
|
||||
|
||||
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
|
||||
using G = ${template};
|
||||
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
|
||||
ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
|
||||
constexpr auto desc =
|
||||
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
|
||||
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${n})));
|
||||
|
||||
static_assert(desc.IsValid(), "Invalid ck gemm.");
|
||||
|
||||
@@ -56,7 +61,6 @@ TEST_CASE(test_problem_kernel)
|
||||
|
||||
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
|
||||
std::cout << "Num solutions: " << solutions.size() << std::endl;
|
||||
|
||||
for(auto i = 0; i < solutions.size(); ++i)
|
||||
{
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
@@ -83,4 +87,34 @@ TEST_CASE(test_problem_kernel)
|
||||
}
|
||||
}
|
||||
|
||||
// Builds a batched GEMM + softmax + GEMM problem description, asks the host
// library for the solutions available on gfx90a, and prints the template
// string of each one.
TEST_CASE(test_gemm_softmax_gemm)
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.TransA  = false;
    prob.TransB  = true;
    prob.TransB1 = false;
    prob.TransC  = false;
    prob.M       = 1024;
    prob.N       = 1024;
    prob.K       = 1024;
    prob.O       = 1024;

    // NOTE(review): `check` and the device buffers below are not used by the
    // rest of this test yet. They are kept because constructing them may have
    // side effects (presumably device allocation/upload) -- confirm before
    // removing.
    check_all<half> check;
    auto a  = to_gpu(generate_buffer<half>(1024 * 1024, 0));
    auto b  = to_gpu(generate_buffer<half>(1024 * 1024, 1));
    auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
    auto c  = to_gpu(generate_buffer<half>(1024 * 1024, 3));

    // No extra code is injected before or after the generated operation.
    std::string epilogue = "";
    std::string prologue = "";

    auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
    std::cout << "Num solutions: " << solutions.size() << std::endl;

    // Use an unsigned index so the comparison against solutions.size() does
    // not mix signed and unsigned operands.
    for(std::size_t i = 0; i < solutions.size(); ++i)
    {
        std::cout << "Solution " << i << std::endl;
        std::cout << solutions[i].ToTemplateString() << std::endl;
        std::cout << std::endl;
    }
}
|
||||
|
||||
// Program entry point: forwards the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
|
||||
|
||||
Reference in New Issue
Block a user