Rebase the PR #1520 to ROCm repo. (#1574)

* Implement hiprtc for codegen tests

* Introduce gemm_softmax_gemm to codegen.

* Fix codegen build issues.

* Address PR comments.

* Separate ck_host lib and gemm_softmax_gemm into different PR.

* Fix cmake.

* Replace ENV variable with CMake option for toggling hipRTC in codegen
tests.

* Address PR comments.

* fix clang format

* Add missing header in magic_division.hpp

* - Workaround for hipRTC content wrapper
- Move descriptor for gemm_softmax_gemm to different branch

* Fix formatting.

* Revert "Fix formatting."

This reverts commit b5209eaef4.

* formatting fix

* fixed header guard issues

* updated header guards

* updated data_type for new types

* fixed redefinition error

* Add codegen test for batched_gemm_softmax_gemm.

Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com>

* formatting fix

---------

Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
Co-authored-by: Dino Musić <dino.music@htecgroup.com>
Co-authored-by: Mirza Halilcevic <mirza.halilcevic@htecgroup.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: arai713 <67439843+arai713@users.noreply.github.com>
Co-authored-by: Astha Rai <astha.rai713@gmail.com>
Co-authored-by: Mirza Halilcevic <mirza.halilcevic@amd.com>

[ROCm/composable_kernel commit: 68a08c872e]
This commit is contained in:
Illia Silin
2025-02-20 18:58:14 -08:00
committed by GitHub
parent bd738ccf57
commit ff04241799
32 changed files with 880 additions and 517 deletions

View File

@@ -0,0 +1,87 @@
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <cmath>
using half = _Float16;
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* b1, ck::half_t* c) {
using G = ${template};
constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${k}), ck::make_tuple(${m}, 1)),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(${n}, 1)),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${o}), ck::make_tuple(1, ${n})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${o}), ck::make_tuple(${m}, 1)));
static_assert(desc.IsValid(), "Invalid ck gemm.");
if constexpr(desc.IsValid())
{
${template}::Run(desc,
1.0,
a,
b,
b1,
c);
}
}
)__ck__";
TEST_CASE(test_problem_kernel)
{
ck::host::device_batched_gemm_softmax_gemm::Problem prob;
prob.M = 1024;
prob.N = 1024;
prob.K = 1024;
prob.O = 1024;
prob.TransB = true;
check_all<half> check1, check2;
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 3));
std::string epilogue = "";
std::string prologue = "";
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)},
{"o", std::to_string(prob.O)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("Gemm01MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("Gemm1NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(
a.data(), b.data(), b1.data(), c.data());
if(solution.GetTemplateParameter<bool>("MaskOutUpperTriangle"))
CHECK(report(solution, check1(rtc::from_gpu(c))));
else
CHECK(report(solution, check2(rtc::from_gpu(c))));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }

View File

@@ -6,134 +6,24 @@
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
#include <cmath>
#include <iterator>
#include <random>
#include <test.hpp>
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <random>
using half = _Float16;
// using half = __fp16;
std::vector<rtc::src_file> get_headers_for_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
});
return result;
}
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
return result;
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
return fabs(x - y) < atol + rtol * fabs(y);
});
}
std::string classify(double x)
{
switch(std::fpclassify(x))
{
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
template <class Buffer>
void print_classification(const Buffer& x)
{
std::unordered_set<std::string> result;
for(const auto& i : x)
result.insert(classify(i));
for(const auto& c : result)
std::cout << c << ", ";
std::cout << std::endl;
}
template <class Buffer>
void print_statistics(const Buffer& x)
{
std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
double num_elements = x.size();
auto mean =
std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
auto stddev = std::sqrt(
std::accumulate(x.begin(),
x.end(),
double{0.0},
[&](double r, double v) { return r + std::pow((v - mean), 2.0); }) /
num_elements);
std::cout << "Mean: " << mean << ", ";
std::cout << "StdDev: " << stddev << "\n";
}
template <class Buffer>
void print_preview(const Buffer& x)
{
if(x.size() <= 10)
{
std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; });
}
else
{
std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; });
std::cout << "..., ";
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; });
}
std::cout << std::endl;
}
template <class T>
struct check_all
{
rtc::buffer<T> data{};
bool operator()(const rtc::buffer<T>& x)
{
if(data.empty())
{
data = x;
return true;
}
if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); }))
return false;
return allclose(data, x);
}
};
template <class Solution>
auto report(const Solution& solution, bool pass)
{
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
}
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
@@ -166,15 +56,19 @@ TEST_CASE(test_problem_kernel)
std::string epilogue = "";
std::string prologue = "";
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
auto src = ck::host::InterpolateString(gemm_compile_check,
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
options.kernel_name = "f";

View File

@@ -2,27 +2,38 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/host/headers.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <iterator>
#include <numeric>
#include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <unordered_set>
std::vector<rtc::src_file> get_headers_for_test()
inline std::vector<rtc::src_file> create_headers_for_test()
{
auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
});
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
std::string content;
content.reserve(p.second.size() + 1);
content.push_back(' '); // We need a whitespace before the content for hipRTC to work
content.append(p.second.data(), p.second.size());
return rtc::src_file{p.first, std::move(content)};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_test();
return headers;
}
template <typename V>
std::size_t GetSize(V mLens, V mStrides)
{
@@ -37,18 +48,24 @@ std::size_t GetSize(V mLens, V mStrides)
return space;
}
template <class T, typename V>
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
rtc::buffer<T> result(space);
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
// std::fill(result.begin(), result.end(), 1);
return result;
}
template <class T, typename V>
std::enable_if_t<!std::is_integral_v<V>, rtc::buffer<T>>
generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
return generate_buffer<T>(space, seed);
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
@@ -57,7 +74,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
});
}
std::string classify(double x)
inline std::string classify(double x)
{
switch(std::fpclassify(x))
{

View File

@@ -4,3 +4,9 @@ add_library(ck_rtc ${RTC_SOURCES})
target_include_directories(ck_rtc PUBLIC include)
target_link_libraries(ck_rtc PUBLIC hip::host)
target_link_libraries(ck_rtc PUBLIC -lstdc++fs)
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()

View File

@@ -12,8 +12,9 @@ namespace rtc {
struct src_file
{
src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {}
fs::path path;
std::string_view content;
std::string content;
};
struct compile_options
@@ -22,7 +23,7 @@ struct compile_options
std::string kernel_name = "main";
};
kernel compile_kernel(const std::vector<src_file>& src,
kernel compile_kernel(const std::vector<src_file>& srcs,
compile_options options = compile_options{});
} // namespace rtc

View File

@@ -3,14 +3,41 @@
#include <rtc/hip.hpp>
#include <rtc/compile_kernel.hpp>
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#include <hip/hiprtc.h>
#include <rtc/manage_ptr.hpp>
#endif
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <algorithm>
#include <cassert>
#include <deque>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
namespace rtc {
bool EndsWith(const std::string& value, const std::string& suffix)
{
if(suffix.size() > value.size())
return false;
else
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
std::vector<std::string> SplitString(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + delim);
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
template <class T>
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
@@ -62,7 +89,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device
// TODO: undo after extracting the codeobj
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
assert(not srcs.empty());
tmp_dir td{"compile"};
@@ -103,4 +130,172 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
return kernel{obj.data(), options.kernel_name};
}
#ifdef HIPRTC_FOR_CODEGEN_TESTS
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
}
void hiprtc_check_error(hiprtcResult err, const std::string& msg = "")
{
if(err != HIPRTC_SUCCESS)
throw std::runtime_error(hiprtc_error(err, msg));
}
struct hiprtc_src_file
{
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
std::string path;
std::string content;
};
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
{
hiprtcProgram prog = nullptr;
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
hiprtc_check_error(result, "Create program failed.");
return p;
}
struct hiprtc_program
{
struct string_array
{
std::deque<std::string> strings{};
std::vector<const char*> c_strs{};
string_array() {}
string_array(const string_array&) = delete;
std::size_t size() const { return strings.size(); }
const char** data() { return c_strs.data(); }
void push_back(std::string s)
{
strings.push_back(std::move(s));
c_strs.push_back(strings.back().c_str());
}
};
hiprtc_program_ptr prog = nullptr;
string_array headers{};
string_array include_names{};
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<src_file> srcs)
{
for(auto&& src : srcs)
{
if(EndsWith(src.path, ".cpp"))
{
cpp_src = std::move(src.content);
cpp_name = std::move(src.path);
}
else
{
headers.push_back(std::move(src.content));
include_names.push_back(std::move(src.path));
}
}
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
headers.data(),
include_names.data());
}
void compile(const std::vector<std::string>& options, bool quiet = false) const
{
std::vector<const char*> c_options;
std::transform(options.begin(),
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty() and not quiet)
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS)
throw std::runtime_error("Compilation failed.");
}
std::string log() const
{
std::size_t n = 0;
hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0)
return {};
std::string buffer(n, '\0');
hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0);
return buffer;
}
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src_file>& srcs,
const compile_options& options)
{
hiprtc_program prog(srcs);
auto flags = SplitString(options.flags, ' ');
prog.compile(flags);
return {prog.get_code_obj()};
}
static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
options.flags += " -I. -O3";
options.flags += " -std=c++17";
options.flags += " --offload-arch=" + get_device_name();
auto cos = compile_hip_src_with_hiprtc(srcs, options);
if(cos.size() != 1)
std::runtime_error("No code object");
auto& obj = cos.front();
return kernel{obj.data(), options.kernel_name};
}
#endif
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
#ifdef HIPRTC_FOR_CODEGEN_TESTS
return hiprtc_compile_kernel(srcs, options);
#else
return clang_compile_kernel(srcs, options);
#endif
}
} // namespace rtc