Files
composable_kernel/codegen/test/gemm_multiple_d.cpp
Illia Silin 68a08c872e Rebase the PR #1520 to ROCm repo. (#1574)
* Implement hiprtc for codegen tests

* Introduce gemm_softmax_gemm to codegen.

* Fix codegen build issues.

* Address PR comments.

* Separate ck_host lib and gemm_softmax_gemm into different PR.

* Fix cmake.

* Replace ENV variable with CMake option for toggling hipRTC in codegen
tests.

* Address PR comments.

* fix clang format

* Add missing header in magic_division.hpp

* - Workaround for hipRTC content wrapper
- Move descriptor for gemm_softmax_gemm to different branch

* Fix formatting.

* Revert "Fix formatting."

This reverts commit b5209eaef4.

* formatting fix

* fixed header guard issues

* updated header guards

* updated data_type for new types

* fixed redefinition error

* Add codegen test for batched_gemm_softmax_gemm.

Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com>

* formatting fix

---------

Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
Co-authored-by: Dino Musić <dino.music@htecgroup.com>
Co-authored-by: Mirza Halilcevic <mirza.halilcevic@htecgroup.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: arai713 <67439843+arai713@users.noreply.github.com>
Co-authored-by: Astha Rai <astha.rai713@gmail.com>
Co-authored-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
2025-02-20 18:58:14 -08:00

88 lines
3.4 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <random>
using half = _Float16;
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
if constexpr(desc.IsValid())
{
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
}
}
)__ck__";
TEST_CASE(test_problem_kernel)
{
ck::host::device_gemm_multiple_d::Problem prob;
prob.M = 1024;
prob.N = 1024;
prob.K = 1024;
check_all<half> check;
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 2));
std::string epilogue = "";
std::string prologue = "";
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
CHECK(report(solution, check(rtc::from_gpu(c))));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }