diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 45c47672b0..9e7c360f54 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -46,7 +46,6 @@ rocm_install_targets( TARGETS ck_host ck_headers EXPORT ck_host_targets INCLUDE include - PRIVATE ) rocm_export_targets( EXPORT ck_host_targets diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..3f0b8bfe6a --- /dev/null +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -0,0 +1,87 @@ +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include +#include +#include +#include + +using half = _Float16; + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* b1, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${k}), ck::make_tuple(${m}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(${n}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${o}), ck::make_tuple(1, ${n})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${o}), ck::make_tuple(${m}, 1))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + 1.0, + a, + b, + b1, + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_batched_gemm_softmax_gemm::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + prob.O = 1024; + prob.TransB = true; + check_all check1, check2; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto b1 = to_gpu(generate_buffer(1024 * 1024, 2)); + auto c = to_gpu(generate_buffer(1024 * 1024, 3)); + + std::string epilogue = ""; 
+ std::string prologue = ""; + + auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) + { + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)( + a.data(), b.data(), b1.data(), c.data()); + + if(solution.GetTemplateParameter("MaskOutUpperTriangle")) + CHECK(report(solution, check1(rtc::from_gpu(c)))); + else + CHECK(report(solution, check2(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 9e2d990d9b..2a383fc1c8 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -6,134 +6,24 @@ #include "ck/host/headers.hpp" #include "ck/host/stringutils.hpp" #include "ck/host/utils.hpp" -#include -#include -#include -#include -#include +#include "common.hpp" #include #include +#include +#include +#include #include +#include +#include using half = _Float16; -// using half = __fp16; - -std::vector 
get_headers_for_test() -{ - std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); - return result; -} - -template -rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) -{ - rtc::buffer result(n); - std::mt19937 gen(seed); - std::uniform_real_distribution dis(-1.0); - std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - return result; -} - -template -bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) -{ - return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { - return fabs(x - y) < atol + rtol * fabs(y); - }); -} - -std::string classify(double x) -{ - switch(std::fpclassify(x)) - { - case FP_INFINITE: return "inf"; - case FP_NAN: return "nan"; - case FP_NORMAL: return "normal"; - case FP_SUBNORMAL: return "subnormal"; - case FP_ZERO: return "zero"; - default: return "unknown"; - } -} - -template -void print_classification(const Buffer& x) -{ - std::unordered_set result; - for(const auto& i : x) - result.insert(classify(i)); - for(const auto& c : result) - std::cout << c << ", "; - std::cout << std::endl; -} - -template -void print_statistics(const Buffer& x) -{ - std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; - std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; - double num_elements = x.size(); - auto mean = - std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; - auto stddev = std::sqrt( - std::accumulate(x.begin(), - x.end(), - double{0.0}, - [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / - num_elements); - std::cout << "Mean: " << mean << ", "; - std::cout << "StdDev: " << stddev << "\n"; -} - -template -void print_preview(const Buffer& x) -{ - if(x.size() <= 10) - { - std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i 
<< ", "; }); - } - else - { - std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); - std::cout << "..., "; - std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); - } - std::cout << std::endl; -} - -template -struct check_all -{ - rtc::buffer data{}; - bool operator()(const rtc::buffer& x) - { - if(data.empty()) - { - data = x; - return true; - } - if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); })) - return false; - return allclose(data, x); - } -}; - -template -auto report(const Solution& solution, bool pass) -{ - return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); -} const std::string gemm_compile_check = R"__ck__( #include <${include}> extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { using G = ${template}; - constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), ck::make_tuple(), ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); @@ -166,15 +56,19 @@ TEST_CASE(test_problem_kernel) std::string epilogue = ""; std::string prologue = ""; - for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) + auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) { - auto src = ck::host::InterpolateString(gemm_compile_check, + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}, {"m", std::to_string(prob.M)}, {"n", 
std::to_string(prob.N)}, {"k", std::to_string(prob.K)}}); - auto srcs = get_headers_for_test(); + auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; options.kernel_name = "f"; diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 24fde2e523..b3be592e74 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -2,27 +2,38 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once + +#include "ck/host/headers.hpp" +#include +#include +#include #include #include #include #include #include -#include -#include -#include -#include +#include -std::vector get_headers_for_test() +inline std::vector create_headers_for_test() { + auto ck_headers = ck::host::GetHeaders(); std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); + std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) { + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; + }); return result; } +inline const std::vector& get_headers_for_test() +{ + static const std::vector headers = create_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { @@ -37,18 +48,24 @@ std::size_t GetSize(V mLens, V mStrides) return space; } -template -rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) { - std::size_t space = GetSize(mLens, mStrides); - rtc::buffer result(space); + rtc::buffer result(n); std::mt19937 gen(seed); std::uniform_real_distribution dis(-1.0); 
std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - // std::fill(result.begin(), result.end(), 1); return result; } +template +std::enable_if_t, rtc::buffer> +generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + return generate_buffer(space, seed); +} + template bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) { @@ -57,7 +74,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) }); } -std::string classify(double x) +inline std::string classify(double x) { switch(std::fpclassify(x)) { diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 68bfc2467b..2e7ceb5648 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -4,3 +4,9 @@ add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) target_link_libraries(ck_rtc PUBLIC -lstdc++fs) + +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." 
ON) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) + message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") +endif() diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index a49714f7c6..207f10a8e8 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -12,8 +12,9 @@ namespace rtc { struct src_file { + src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {} fs::path path; - std::string_view content; + std::string content; }; struct compile_options @@ -22,7 +23,7 @@ struct compile_options std::string kernel_name = "main"; }; -kernel compile_kernel(const std::vector& src, +kernel compile_kernel(const std::vector& srcs, compile_options options = compile_options{}); } // namespace rtc diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 5a70f898e8..a8da88be09 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -3,14 +3,41 @@ #include #include +#ifdef HIPRTC_FOR_CODEGEN_TESTS +#include +#include +#endif #include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include namespace rtc { +bool EndsWith(const std::string& value, const std::string& suffix) +{ + if(suffix.size() > value.size()) + return false; + else + return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin()); +} + +std::vector SplitString(const std::string& s, char delim) +{ + std::vector elems; + std::stringstream ss(s + delim); + std::string item; + while(std::getline(ss, item, delim)) + { + elems.push_back(item); + } + return elems; +} + template T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) { @@ -62,7 +89,7 @@ std::string compiler() { return 
"/opt/rocm/llvm/bin/clang++ -x hip --cuda-device // TODO: undo after extracting the codeobj // std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } -kernel compile_kernel(const std::vector& srcs, compile_options options) +kernel clang_compile_kernel(const std::vector& srcs, compile_options options) { assert(not srcs.empty()); tmp_dir td{"compile"}; @@ -103,4 +130,172 @@ kernel compile_kernel(const std::vector& srcs, compile_options options return kernel{obj.data(), options.kernel_name}; } +#ifdef HIPRTC_FOR_CODEGEN_TESTS + +std::string hiprtc_error(hiprtcResult err, const std::string& msg) +{ + return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); +} + +void hiprtc_check_error(hiprtcResult err, const std::string& msg = "") +{ + if(err != HIPRTC_SUCCESS) + throw std::runtime_error(hiprtc_error(err, msg)); +} + +struct hiprtc_src_file +{ + hiprtc_src_file() = default; + hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {} + std::string path; + std::string content; +}; + +void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } +using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); + +template +hiprtc_program_ptr hiprtc_program_create(Ts... 
xs) +{ + hiprtcProgram prog = nullptr; + auto result = hiprtcCreateProgram(&prog, xs...); + hiprtc_program_ptr p{prog}; + hiprtc_check_error(result, "Create program failed."); + return p; +} + +struct hiprtc_program +{ + struct string_array + { + std::deque strings{}; + std::vector c_strs{}; + + string_array() {} + string_array(const string_array&) = delete; + + std::size_t size() const { return strings.size(); } + + const char** data() { return c_strs.data(); } + + void push_back(std::string s) + { + strings.push_back(std::move(s)); + c_strs.push_back(strings.back().c_str()); + } + }; + + hiprtc_program_ptr prog = nullptr; + string_array headers{}; + string_array include_names{}; + std::string cpp_src = ""; + std::string cpp_name = ""; + + hiprtc_program(const std::string& src, const std::string& name = "main.cpp") + : cpp_src(src), cpp_name(name) + { + create_program(); + } + + hiprtc_program(std::vector srcs) + { + for(auto&& src : srcs) + { + if(EndsWith(src.path, ".cpp")) + { + cpp_src = std::move(src.content); + cpp_name = std::move(src.path); + } + else + { + headers.push_back(std::move(src.content)); + include_names.push_back(std::move(src.path)); + } + } + create_program(); + } + + void create_program() + { + assert(not cpp_src.empty()); + assert(not cpp_name.empty()); + assert(headers.size() == include_names.size()); + prog = hiprtc_program_create(cpp_src.c_str(), + cpp_name.c_str(), + headers.size(), + headers.data(), + include_names.data()); + } + + void compile(const std::vector& options, bool quiet = false) const + { + std::vector c_options; + std::transform(options.begin(), + options.end(), + std::back_inserter(c_options), + [](const std::string& s) { return s.c_str(); }); + auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); + auto prog_log = log(); + if(not prog_log.empty() and not quiet) + { + std::cerr << prog_log << std::endl; + } + if(result != HIPRTC_SUCCESS) + throw std::runtime_error("Compilation failed."); + 
} + + std::string log() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n)); + if(n == 0) + return {}; + std::string buffer(n, '\0'); + hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data())); + assert(buffer.back() != 0); + return buffer; + } + + std::vector get_code_obj() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n)); + std::vector buffer(n); + hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data())); + return buffer; + } +}; + +std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, + const compile_options& options) +{ + hiprtc_program prog(srcs); + auto flags = SplitString(options.flags, ' '); + prog.compile(flags); + return {prog.get_code_obj()}; +} + +static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_options options) +{ + options.flags += " -I. -O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + auto cos = compile_hip_src_with_hiprtc(srcs, options); + if(cos.size() != 1) + std::runtime_error("No code object"); + auto& obj = cos.front(); + return kernel{obj.data(), options.kernel_name}; +} + +#endif + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ +#ifdef HIPRTC_FOR_CODEGEN_TESTS + return hiprtc_compile_kernel(srcs, options); +#else + return clang_compile_kernel(srcs, options); +#endif +} + } // namespace rtc diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index e9df8c9f5f..d61b5e2b27 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.15.0 +rocm-docs-core==1.17.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a42fdf09bf..177f3ec184 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -199,7 +199,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.15.0 +rocm-docs-core==1.17.0 # 
via -r requirements.in rpds-py==0.22.3 # via diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 1ec0c6bc23..c8d1c20f4c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,8 +4,9 @@ #pragma once #include "ck/config.h" + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include "ck/utility/env.hpp" -#ifndef CK_CODE_GEN_RTC #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 05dc491af7..e04e27b761 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -3,6 +3,7 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include #include @@ -97,3 +98,4 @@ inline bool is_gfx12_supported() } } // namespace ck +#endif diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 962f89e479..5c1c1c4e60 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once - +#ifndef __HIPCC_RTC__ #include #include "ck/ck.hpp" @@ -166,3 +166,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, return 0; #endif } +#endif diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 774982d905..9285211519 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -3,11 +3,12 @@ #pragma once -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #include #include + #include "ck/stream_config.hpp" #endif @@ -15,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #define GET_OBJECT_NAME_IMLP \ std::optional GetObjectName() const override \ { \ @@ -77,7 +78,7 @@ struct BaseOperator BaseOperator() = default; BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index 09259224e7..204b09cad4 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -2,9 +2,10 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once - +#ifndef __HIPCC_RTC__ #include #include +#endif #include "device_base.hpp" @@ -28,6 +29,7 @@ template // TODO: enum for mask type struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator { +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 9006e70040..cf0184839e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -2,9 +2,11 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef __HIPCC_RTC__ #include +#endif +#include "ck/utility/array.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator CDEElementwiseOperation cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; // GEMM: @@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -94,6 +99,52 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator CDEElementwiseOperation cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... 
+// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; }; // GEMM: diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 997dcb75a6..8824f44ec5 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -28,8 +28,7 @@ enum struct GemmSpecialization NKOPadding, MNKOPadding, }; - -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::string getGemmSpecializationString(const GemmSpecialization& s) { switch(s) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index ea5a5d0e16..b4ab96d397 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -3,8 +3,12 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include 
"ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -15,8 +19,6 @@ #include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -429,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle matrix_padder.PadN, MaskOutUpperTriangle>; +#ifndef __HIPCC_RTC__ // Argument struct Argument : public BaseArgument { @@ -603,6 +606,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return Run(*dynamic_cast(p_arg), stream_config); } }; +#endif static constexpr bool IsValidCompilationParameter() { @@ -610,6 +614,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } +#ifndef __HIPCC_RTC__ static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) { @@ -837,6 +842,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } +#endif template struct Descriptor diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index e6466a487b..3fae3a3765 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -3,8 +3,12 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -14,8 +18,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include 
"ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { @@ -224,9 +226,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD& MRaws, - const std::array& NRaws, - const std::array& DsStride) + static auto MakeDsGridDescriptor_M_N(const Array& MRaws, + const Array& NRaws, + const Array& DsStride) { return generate_tuple( [&](auto i) { @@ -308,6 +310,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; +#ifndef __HIPCC_RTC__ // Argument struct Argument : public BaseArgument { @@ -497,6 +500,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD LoopSchedToString{ - {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + std::map LoopSchedToString{{LoopScheduler::Default, "Default"}, + { LoopScheduler::Interwave, + "Interwave" }}; std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, - {PipelineVersion::v2, "v2"}}; + { PipelineVersion::v2, + "v2" }}; // clang-format off str << "DeviceGemmMultipleD_Xdl_CShuffle" @@ -708,6 +716,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD struct Descriptor @@ -846,7 +855,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index bce75f8f34..3d7ffbd163 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -356,10 +356,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle throw std::runtime_error("todo: only v1 v2 and v3 support now"); } } - else - { - throw std::runtime_error("not call 
kernel function"); - } #if 0 else { diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 0ec55984bc..9fe2f0d976 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -13,6 +13,7 @@ enum struct MaskingSpecialization MaskOutUpperTriangle }; +#ifndef __HIPCC_RTC__ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) { switch(s) @@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s default: return "Unrecognized specialization!"; } } +#endif struct MaskDisabledPredicate { @@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate template struct C0MatrixMask_impl { - __host__ __device__ C0MatrixMask_impl(index_t NRaw) + __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) { } diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index 4a44177838..e836e73a1d 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -436,7 +436,7 @@ struct G_NDHW : public BaseTensorLayout } // namespace convolution -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template < typename Layout, typename std::enable_if::value, bool>::type = false> diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index be4e68bffa..f1d0f9844d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -697,7 +697,7 @@ struct FastGelu template __device__ void operator()(Y& y, const X& x) const; 
-#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template <> __host__ void operator()(float& y, const float& x) const { @@ -709,7 +709,6 @@ struct FastGelu y = x / (1.f + emu); } #endif - // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 2bc9ef87ac..64fad1ca48 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -8,7 +8,7 @@ #include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index eb1eb533d7..060f6d5d15 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -473,7 +473,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template __host__ __device__ static auto MakeDsGridDescriptor_M_N(const ck::Array& MRaws, @@ -486,6 +486,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const std::array& NRaws, const std::array& DsStride) #endif + { return generate_tuple( [&](auto i) { @@ -949,7 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t K, const index_t StrideA, const index_t StrideB, -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) const ck::Array 
StrideDs, #else const std::array StrideDs, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 9dad66913a..f8de0a48e5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -2,7 +2,8 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#ifndef CK_CODE_GEN_RTC + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #endif @@ -54,7 +55,7 @@ constexpr auto GridwiseGemmPipeline_Selector() } else { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; #endif } @@ -62,7 +63,7 @@ constexpr auto GridwiseGemmPipeline_Selector() } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) { switch(p) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 4f20487b9b..8c0b950941 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -780,7 +780,6 @@ struct mfma_type } }; -// TODO: fix mfma...f8f6f4 instructions template <> struct mfma_type { @@ -847,9 +846,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; @@ -871,9 +875,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 328e37d009..317f324e6d 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1008,6 +1008,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); +#ifndef __HIPCC_RTC__ template __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset, @@ -1059,5 +1060,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); #endif } +#endif } // namespace ck diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 128c8e9a2c..3604712837 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -7,7 +7,7 @@ #include "ck/utility/functional2.hpp" 
#include "ck/utility/math.hpp" -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #include diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b125e3adf6..010b7aabd3 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -533,9 +533,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; @@ -569,9 +569,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index f90fcf6791..2e3b09eae9 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -6,16 +6,20 @@ #include "ck/utility/amd_ck_fp8.hpp" #include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" -#ifdef CK_CODE_GEN_RTC + +/// Definitions from , conflict with +/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h. 
+ +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; using uint16_t = unsigned short; using float_t = float; -#endif -namespace ck { +#endif // __HIPCC_RTC__ -#ifdef CK_CODE_GEN_RTC +namespace ck { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) using byte = unsigned char; #else using std::byte; @@ -2612,7 +2616,7 @@ using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template struct NumericLimits; @@ -2825,6 +2829,118 @@ struct NumericLimits return bit_cast(binary_qnan); } }; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static 
constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr 
bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; #else template struct NumericLimits @@ -2959,7 +3075,6 @@ struct NumericLimits return bit_cast(binary_qnan); } }; -#endif template <> struct NumericLimits @@ -3072,6 +3187,7 @@ struct NumericLimits return e8m0_bexp_t(binary_142); } }; +#endif template struct NumericUtils diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 6ba63fc761..9d5403ceb2 100644 --- 
a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -4,15 +4,7 @@ #pragma once namespace ck { - -#ifndef CK_CODE_GEN_RTC -template -using enable_if = std::enable_if; - -template -using enable_if_t = typename std::enable_if::type; - -#else +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template struct enable_if { @@ -26,6 +18,12 @@ struct enable_if template using enable_if_t = typename enable_if::type; -#endif +#else +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; +#endif } // namespace ck diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index 837ff66312..cbbce85007 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_CODE_GEN_RTC +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #endif -#pragma once - #include "ck/utility/common_header.hpp" namespace ck { @@ -28,7 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler() } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) { switch(s) diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 03eb7c646d..05ae9093e2 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/ck.hpp" +#include "data_type.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -34,7 +35,7 @@ struct MagicDivision // WARNING: magic division is only applicable for division inside this range. // You should use the return value of CalculateMagicNumbers, if division is not inside this // range. 
The "else" logic below is to quiet down run-time error. - if(divisor >= 1 && divisor <= INT32_MAX) + if(divisor >= 1 && divisor <= ck::NumericLimits::Max()) { uint32_t shift = 0; for(shift = 0; shift < 32; ++shift) diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index b31b46fb5f..e235f51c93 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float); #endif // math functions for the host, some are implemented by calling C++ std functions -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); }; @@ -924,5 +924,23 @@ inline __device__ double expm1(double x) return expm1(x); }; +template +inline __device__ T cos(T x) +{ + return ck::type_convert(cosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cos(float x) +{ + return cosf(x); +}; + +template <> +inline __device__ double cos(double x) +{ + return cos(x); +}; + } // namespace math } // namespace ck diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6061d48118..25dae4e335 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -3,7 +3,7 @@ #pragma once -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #endif @@ -902,7 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen::type; } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template std::ostream& operator<<(std::ostream& os, const ck::Sequence) { diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 596c748a2a..b4f1545aa9 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -159,7 +159,7 @@ __host__ __device__ 
constexpr auto TupleReduce(F&& f, const Tuple& tuple) } } -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template using is_tuple = decltype(ck::declval().IsTuple()); #endif @@ -167,7 +167,7 @@ using is_tuple = decltype(ck::declval().IsTuple()); template __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) return (is_detected::value || ...); #endif } diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index ef9326ae57..bde9c179ce 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,316 +1,313 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/utility/enable_if.hpp" -#include "ck/utility/integral_constant.hpp" - -namespace ck { -#ifdef CK_CODE_GEN_RTC -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAIT1(name) \ - template \ - struct name : bool_constant<__##name(T)> \ - { \ - } - -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAIT2(name) \ - template \ - struct name : bool_constant<__##name(T, U)> \ - { \ - } - -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAITN(name) \ - template \ - struct name : bool_constant<__##name(Ts...)> \ - { \ - } - -CK_BUILTIN_TYPE_TRAIT1(is_class); -CK_BUILTIN_TYPE_TRAIT1(is_pointer); -CK_BUILTIN_TYPE_TRAIT1(is_reference); -CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); -CK_BUILTIN_TYPE_TRAIT1(is_unsigned); -CK_BUILTIN_TYPE_TRAIT2(is_base_of); - -template -struct remove_cv -{ - using type = T; -}; - -template -struct remove_cv : remove_cv -{ -}; - -template -struct remove_cv : remove_cv -{ -}; - -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct 
remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; - -template -constexpr T&& forward(typename remove_reference::type& t_) noexcept -{ - return static_cast(t_); -} -template -constexpr T&& forward(typename remove_reference::type&& t_) noexcept -{ - return static_cast(t_); -} - -template -struct is_const : public integral_constant -{ -}; -template -struct is_const : public integral_constant -{ -}; -template -inline constexpr bool is_const_v = is_const::value; - -template -inline constexpr bool is_reference_v = is_reference::value; - -template -struct remove_const -{ - typedef T type; -}; -template -struct remove_const -{ - typedef T type; -}; -template -using remove_const_t = typename remove_const::type; -template -inline constexpr bool is_class_v = is_class::value; - -template -inline constexpr bool is_trivially_copyable_v = is_trivially_copyable::value; -// template -// T&& declval() noexcept; - -template -U private_declval(int); - -template -T private_declval(long); - -template -auto declval() noexcept -> decltype(private_declval(0)); - -template -using void_t = void; -#else -#include -#include -using std::declval; -using std::forward; -using std::is_base_of; -using std::is_class; -using std::is_class_v; -using std::is_const_v; -using std::is_pointer; -using std::is_reference; -using std::is_reference_v; -using std::is_trivially_copyable; -using std::is_trivially_copyable_v; -using std::is_unsigned; -using std::remove_const_t; -using std::remove_cv; -using std::remove_pointer; -using std::remove_reference; -using std::void_t; -#endif - -template -struct is_same : public integral_constant -{ -}; - -template -struct is_same : public integral_constant -{ -}; - -template -struct is_floating_point : public integral_constant -{ -}; - -template <> -struct is_floating_point : public integral_constant -{ -}; - 
-template <> -struct is_floating_point : public integral_constant -{ -}; -template <> -struct is_floating_point : public integral_constant -{ -}; - -template -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template -inline constexpr bool is_same_v = is_same::value; - -template -inline constexpr bool is_base_of_v = is_base_of::value; - -template -inline constexpr bool is_unsigned_v = is_unsigned::value; - -template -using remove_reference_t = typename remove_reference::type; - -template -using remove_reference_t = typename remove_reference::type; - -template -using remove_cv_t = typename remove_cv::type; -template -using remove_cvref_t = remove_cv_t>; - -template -using remove_pointer_t = typename remove_pointer::type; - -template -inline constexpr bool is_pointer_v = is_pointer::value; - -template ::type = false> -__host__ __device__ constexpr Y bit_cast(const X& x) -{ - static_assert(__has_builtin(__builtin_bit_cast), ""); - 
static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); - - return __builtin_bit_cast(Y, x); -} -} // namespace ck +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAIT1(name) \ + template \ + struct name : bool_constant<__##name(T)> \ + { \ + } + +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAIT2(name) \ + template \ + struct name : bool_constant<__##name(T, U)> \ + { \ + } + +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAITN(name) \ + template \ + struct name : bool_constant<__##name(Ts...)> \ + { \ + } + +CK_BUILTIN_TYPE_TRAIT1(is_class); +CK_BUILTIN_TYPE_TRAIT1(is_pointer); +CK_BUILTIN_TYPE_TRAIT1(is_reference); +CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); +CK_BUILTIN_TYPE_TRAIT1(is_unsigned); +CK_BUILTIN_TYPE_TRAIT2(is_base_of); + +template +struct remove_cv +{ + using type = T; +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; + +template +constexpr T&& forward(typename remove_reference::type& t_) noexcept +{ + return static_cast(t_); +} +template +constexpr T&& forward(typename remove_reference::type&& t_) noexcept +{ + return static_cast(t_); +} + +template +struct is_const : public integral_constant +{ 
+}; +template +struct is_const : public integral_constant +{ +}; +template +inline constexpr bool is_const_v = is_const::value; + +template +inline constexpr bool is_reference_v = is_reference::value; + +template +struct remove_const +{ + typedef T type; +}; +template +struct remove_const +{ + typedef T type; +}; +template +using remove_const_t = typename remove_const::type; +template +inline constexpr bool is_class_v = is_class::value; + +template +inline constexpr bool is_trivially_copyable_v = is_trivially_copyable::value; +// template +// T&& declval() noexcept; + +template +U private_declval(int); + +template +T private_declval(long); + +template +auto declval() noexcept -> decltype(private_declval(0)); + +template +using void_t = void; +#else +#include +#include +using std::declval; +using std::forward; +using std::is_base_of; +using std::is_class; +using std::is_class_v; +using std::is_const_v; +using std::is_pointer; +using std::is_reference; +using std::is_reference_v; +using std::is_trivially_copyable; +using std::is_trivially_copyable_v; +using std::is_unsigned; +using std::remove_const_t; +using std::remove_cv; +using std::remove_pointer; +using std::remove_reference; +using std::void_t; +#endif + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_floating_point : public integral_constant +{ +}; + +template <> +struct is_floating_point : public integral_constant +{ +}; + +template <> +struct is_floating_point : public integral_constant +{ +}; +template <> +struct is_floating_point : public integral_constant +{ +}; + +template +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template 
<> +struct is_integral : public integral_constant +{ +}; +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template +inline constexpr bool is_same_v = is_same::value; + +template +inline constexpr bool is_base_of_v = is_base_of::value; + +template +inline constexpr bool is_unsigned_v = is_unsigned::value; + +template +using remove_reference_t = typename remove_reference::type; + +template +using remove_cv_t = typename remove_cv::type; +template +using remove_cvref_t = remove_cv_t>; + +template +using remove_pointer_t = typename remove_pointer::type; + +template +inline constexpr bool is_pointer_v = is_pointer::value; + +template ::type = false> +__host__ __device__ constexpr Y bit_cast(const X& x) +{ + static_assert(__has_builtin(__builtin_bit_cast), ""); + static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); + + return __builtin_bit_cast(Y, x); +} +} // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index e9b2e3fff2..cf862ae640 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -279,7 +279,6 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; - #ifndef CK_CODE_GEN_RTC 
uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else @@ -344,7 +343,6 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; - #ifndef CK_CODE_GEN_RTC uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else @@ -1981,7 +1979,7 @@ inline __host__ __device__ float32_t type_convert(bf6x32_t #endif } -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template inline __host__ __device__ void array_convert(std::array& y, const std::array& x) diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index cc612794f4..f65e89bb82 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mx_mfma_kernel = ck::matmul; + const auto mfma_kernel = ck::matmul; bool pass = true; - pass = ck::mfma_test::TestMFMA{}(mx_mfma_kernel, init); + BLOCK_K>{}(mfma_kernel, init); return pass; } TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } + +/** + * @brief Run the test for the given MX MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mxmfma_test(ck::index_t init) +{ + static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || + mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, + "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported"); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType 
= float; // only MFMA_F32 instructions supported + using ScaleType = ck::e8m0_bexp_t; // biased exponent type + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + constexpr auto BLOCK_X = 32; // scaling vector size + + const auto mx_mfma_kernel = + ck::matmul; + + bool pass = true; + + pass = ck::mxmfma_test::TestMXMFMA{}(mx_mfma_kernel, init); + + return pass; +} + +TEST(MXMFMA, MXFP8MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP8MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index e96e1b0b29..1f9091ebc5 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include "ck/ck.hpp" @@ -7,7 +10,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" namespace ck { @@ -18,7 +21,13 @@ enum class MFMA_F8F6F4 F32_16x16x128 = static_cast(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4 F32_32x32x64 = - static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4 + static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4 + + SCALE_F32_16x16x128 = static_cast( + MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4 + SCALE_F32_32x32x64 = static_cast( + MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4 + }; template @@ -32,6 +41,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -42,6 +62,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<32, 32, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -52,151 +83,428 @@ static constexpr int32_t vectorSize(const VecT&) // Define a load function for input A blocks: // 
Size: (BLOCK_M x BLOCK_K) -// ASSUMPTION: -// - We want contiguous BLOCK_M sized column neighbors in register. -// - Data is in col_major format -// This means: -// - From A we will load K columns of size BLOCK_M to satisfy our input data +// - Data is in column major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off // Register Mapping for 16x128: || Register Mapping for 32x64: - // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | - // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | 
K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 
47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | 
v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_M x BLOCK_K block of data. - static constexpr uint32_t VW = vectorSize(AFragT{}); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. 
- auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW); // Col - auto stepCoord2D = std::make_pair(0u, 1u); + auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; // BLOCK_M is a stride in A matrix - auto startOffset = col_major(startCoord2D, BLOCK_M); - auto kOffset = col_major(stepCoord2D, BLOCK_M); + auto startOffset = col_major(startCoord2D, BLOCK_M); + auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - // kOffset == BLOCK_M - // This means every BLOCK_M element is loaded into output vector - auto fragA = AScalarFragT{}; -#pragma unroll VW - for(uint32_t i = 0; i < VW; i++) + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + AScalarFragT fragA{}; + +#pragma unroll + for(int chunk = 0; chunk < 2; chunk++) { - fragA[i] = bit_cast(input_ptr[startOffset + i * kOffset]); +#pragma unroll + for(uint32_t i = 0; i < chunk_size; i++) + { + fragA[chunk * chunk_size + i] = + bit_cast(input_ptr[startOffset + chunk * kMajorOffset + i * kMinorOffset]); + } } return fragA; } +// Define a load function for input A blocks: +// Size: (BLOCK_M x BLOCK_K) +// - Data is in row major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements +template +__device__ AFragT load_A_row_major(AType const* input_ptr) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || 
Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || 
Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + // clang-format on + + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 32 elements. + // We need to know where they start, and where the next elements are. 
+ auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K is a stride in A matrix + auto startOffset = row_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + union + { + AFragT frag; + AScalarFragT chunks[2]; + } fragA{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragA.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragA.chunks[1] = *fragPtr; + + return fragA.frag; +} + +// Define a load function for scaled A blocks: +// Size: (BLOCK_M x BLOCK_K) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ AFragT load_mx_A_row_major(AType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | | + // M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) | + // Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) | + // Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) | + // Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) | + // Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) | + // Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) | + // Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) | + // Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) | + // Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) | + // Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) | + // Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) | + // Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) | + // Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) | + // Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) | + // Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) | + // Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | 
x(M,0) | + // Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) | + // Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) | + // Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) | + // Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) | + // Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) | + // Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) | + // Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) | + // Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) | + // Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) | + // Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) | + // Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) | + // Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) | + // Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) | + // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | + // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | + // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + // clang-format on + static constexpr uint32_t VW = vectorSize(AFragT{}); + static_assert(VW == BLOCK_X, 
"Fragment size must be equal to BLOCK_X"); + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 1 element + // We need to know where they start + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K / BLOCK_X is a stride in xA matrix + auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X); + + // obtain 8-bit exponent + fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + + return load_A_row_major(input_ptr); +} + // Define a load function for input B blocks: // Size: (BLOCK_K x BLOCK_N) -// ASSUMPTION: -// - We want contiguous BLOCK_N sized row neighbors in register. -// - Data is in row_major format -// This means: -// - From B we will load K rows of size BLOCK_N to satisfy our input data +// - Data is in col major format +// - Cols are loaded in contiguous chunks that map to corresponding microscales +// - Each col is loaded in chunks of size 16 and each thread loads 32 elements template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off // Register Mapping for 128x16: || Register Mapping for 64x32: - // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | - // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 
63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 
5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] 
| K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_K x BLOCK_N block of data. - static constexpr uint32_t VW = vectorSize(BFragT{}); + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from cols of B in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by an offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64 // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = + std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48} + threadIdx.x % BLOCK_N); // Col {0-31} | {0-15} // Flatten to 1D col_major offsets. 
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; - auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols + auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - auto const* fragPtr = reinterpret_cast(input_ptr + startOffset); - return *fragPtr; + // BLOCK_K is a stride in B matrix + auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + + using BRawT = typename scalar_type::type; + using BScalarFragT = vector_type::type; + + union + { + BFragT frag; + BScalarFragT chunks[2]; + } fragB{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragB.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragB.chunks[1] = *fragPtr; + + return fragB.frag; +} + +// Define a load function for scaled B blocks: +// Size: (BLOCK_K x BLOCK_N) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ BFragT load_mx_B_col_major(BType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) + +{ + // clang-format off + // Register Mapping for 128x16: || Register Mapping for 64x32: + // Size | BLOCK_N | BLOCK_N | | BLOCK_N | BLOCK_N | | || Size | BLOCK_N | BLOCK_N | | | + // N | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || N | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(0,N) | K32 | K48 | x(1,N) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(0,N) | + // Reg 0 [8:15] | K1 | K17 | x(0,N) | K33 | K49 | x(1,N) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(0,N) | + // Reg 0 [16:23] | K2 | K18 | x(0,N) | K34 | K50 | x(1,N) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(0,N) | + // Reg 0 [24:31] | K3 | K19 | x(0,N) | K35 | K51 | x(1,N) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(0,N) | + // Reg 1 [0:7] | K4 | K20 | x(0,N) | K36 | K52 | x(1,N) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(0,N) | + // Reg 1 [8:15] | K5 | K21 | x(0,N) | K37 | K53 | x(1,N) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(0,N) | + // Reg 1 [16:23] | K6 | K22 | x(0,N) | K38 | K54 | x(1,N) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(0,N) | + // Reg 1 [24:31] | K7 | K23 | x(0,N) | K39 | K55 | x(1,N) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(0,N) | + // Reg 2 [0:7] | K8 | K24 | x(0,N) | K40 | K56 | x(1,N) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(0,N) | + // Reg 2 [8:15] | K9 | K25 | x(0,N) | K41 | K57 | x(1,N) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(0,N) | + // Reg 2 [16:23] | K10 | K26 | x(0,N) | K42 | K58 | x(1,N) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(0,N) | + // Reg 2 [24:31] | K11 | K27 | x(0,N) | K43 | K59 | x(1,N) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(0,N) | + // Reg 3 [0:7] | K12 | K28 | x(0,N) | K44 | K60 | x(1,N) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(0,N) | + // Reg 3 [8:15] | K13 | K29 | x(0,N) | K45 | K61 | x(1,N) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(0,N) | + // Reg 3 [16:23] | K14 | K30 | x(0,N) | K46 | K62 | x(1,N) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(0,N) | + // Reg 3 [24:31] | K15 | K31 | x(0,N) | K47 | K63 | x(1,N) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | 
x(0,N) |
+    // Reg 4 [0:7]   | K64 | K80 | x(2,N) | K96  | K112 | x(3,N) | v[16] || Reg 4 [0:7]   | K32 | K48 | v[16] | x(1,N) |
+    // Reg 4 [8:15]  | K65 | K81 | x(2,N) | K97  | K113 | x(3,N) | v[17] || Reg 4 [8:15]  | K33 | K49 | v[17] | x(1,N) |
+    // Reg 4 [16:23] | K66 | K82 | x(2,N) | K98  | K114 | x(3,N) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(1,N) |
+    // Reg 4 [24:31] | K67 | K83 | x(2,N) | K99  | K115 | x(3,N) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(1,N) |
+    // Reg 5 [0:7]   | K68 | K84 | x(2,N) | K100 | K116 | x(3,N) | v[20] || Reg 5 [0:7]   | K36 | K52 | v[20] | x(1,N) |
+    // Reg 5 [8:15]  | K69 | K85 | x(2,N) | K101 | K117 | x(3,N) | v[21] || Reg 5 [8:15]  | K37 | K53 | v[21] | x(1,N) |
+    // Reg 5 [16:23] | K70 | K86 | x(2,N) | K102 | K118 | x(3,N) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(1,N) |
+    // Reg 5 [24:31] | K71 | K87 | x(2,N) | K103 | K119 | x(3,N) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(1,N) |
+    // Reg 6 [0:7]   | K72 | K88 | x(2,N) | K104 | K120 | x(3,N) | v[24] || Reg 6 [0:7]   | K40 | K56 | v[24] | x(1,N) |
+    // Reg 6 [8:15]  | K73 | K89 | x(2,N) | K105 | K121 | x(3,N) | v[25] || Reg 6 [8:15]  | K41 | K57 | v[25] | x(1,N) |
+    // Reg 6 [16:23] | K74 | K90 | x(2,N) | K106 | K122 | x(3,N) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(1,N) |
+    // Reg 6 [24:31] | K75 | K91 | x(2,N) | K107 | K123 | x(3,N) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(1,N) |
+    // Reg 7 [0:7]   | K76 | K92 | x(2,N) | K108 | K124 | x(3,N) | v[28] || Reg 7 [0:7]   | K44 | K60 | v[28] | x(1,N) |
+    // Reg 7 [8:15]  | K77 | K93 | x(2,N) | K109 | K125 | x(3,N) | v[29] || Reg 7 [8:15]  | K45 | K61 | v[29] | x(1,N) |
+    // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) |
+    // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
+
+    // clang-format on
+    static constexpr uint32_t VW = vectorSize(BFragT{});
+    static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
+
+    // To start the loading process, let's visualize in 2D coords.
+    // Each thread will load 1 element
+    // We need to know where to start
+    auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row
+                                       threadIdx.x % BLOCK_N);                 // Col
+
+    // Flatten to 1D col_major offsets.
+    auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
+
+    auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X);
+
+    // obtain 8-bit exponent
+    fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF;
+
+    return load_B_col_major(input_ptr);
 }

 // Define a store function for C
@@ -309,6 +617,129 @@ struct store_C_col_major
     }
 };

+// Define a store function for C
+// Size: (BLOCK_M x BLOCK_N)
+// ASSUMPTION:
+// - We want contiguous BLOCK_N sized row neighbors in register.
+// - Data is in row major format
+template
+struct store_C_row_major;
+
+// Here we want to store a 16x16 block of data.
+//
+// Size              |   BLOCK_N  |   BLOCK_N   |   BLOCK_N  |   BLOCK_N    |
+// N                 | 0  ...  15 |  0  ...  15 | 0  ...  15 |  0  ...  15  |
+// Thread Id         | 0  ...  15 | 16  ...  31 | 32 ...  47 | 48  ...  63  |  Vector
+// Register Element   ------------ ------------- ------------ -------------- Element
+// Reg0              |     M0     |     M4      |     M8     |     M12      |  v[0]
+// Reg1              |     M1     |     M5      |     M9     |     M13      |  v[1]
+// Reg2              |     M2     |     M6      |     M10    |     M14      |  v[2]
+// Reg3              |     M3     |     M7      |     M11    |     M15      |  v[3]
+template
+struct store_C_row_major
+{
+    // Writes this lane's part of the 16x16 result tile to row-major memory.
+    //   output: base pointer of the tile; its leading dimension is Dim (16).
+    //   cFrag:  this lane's VW accumulator values; element i goes to row
+    //           (threadIdx.x / Dim) * VW + i, column threadIdx.x % Dim.
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
+        static constexpr uint32_t Dim = 16;
+
+        // Each thread will store 4 elements.
+        // We need to know where they start, and where the next elements are.
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+        auto stepCoord2D  = std::make_pair(1u, 0u);
+
+        // Flatten to 1D row_major offsets.
+        auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
+
+        auto startOffset = row_major(startCoord2D, Dim);
+        auto kOffset     = row_major(stepCoord2D, Dim);
+
+        // If you notice carefully, kOffset != 1 (it is one full row, i.e. Dim).
+        // The fragment therefore must not be written as one contiguous vector through a
+        // CFragT pointer: that would put v[1..3] into the columns owned by neighbouring
+        // lanes and would use an address that is not aligned for CFragT.
+        // Instead the vector is written to 4 non-contiguous offsets, which the compiler
+        // will separate into 4 different global_store_dword instructions.
+        output[startOffset]               = cFrag[0]; // v[0] = Reg 0
+        output[startOffset + kOffset]     = cFrag[1]; // v[1] = Reg 1
+        output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2
+        output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3
+    }
+};
+
+// Here we want to store a 32x32 block of data.
+// Register Mapping:
+
+// Size              |   BLOCK_N  |   BLOCK_N   |
+// N                 | 0  ...  31 |  0  ...  31 |
+// Thread Id         | 0  ...  31 | 32  ...  63 |  Vector
+// Register Element   ------------ -------------  Element
+// Reg0              |     M0     |     M4      |  v[0]
+// Reg1              |     M1     |     M5      |  v[1]
+// Reg2              |     M2     |     M6      |  v[2]
+// Reg3              |     M3     |     M7      |  v[3]
+//                    ____________ _____________
+// Reg4              |     M8     |     M12     |  v[4]
+// Reg5              |     M9     |     M13     |  v[5]
+// Reg6              |     M10    |     M14     |  v[6]
+// Reg7              |     M11    |     M15     |  v[7]
+//                    ____________ _____________
+// Reg8              |     M16    |     M20     |  v[8]
+// Reg9              |     M17    |     M21     |  v[9]
+// Reg10             |     M18    |     M22     |  v[10]
+// Reg11             |     M19    |     M23     |  v[11]
+//                    ____________ _____________
+// Reg12             |     M24    |     M28     |  v[12]
+// Reg13             |     M25    |     M29     |  v[13]
+// Reg14             |     M26    |     M30     |  v[14]
+// Reg15             |     M27    |     M31     |  v[15]
+
+template
+struct store_C_row_major
+{
+    __device__ void operator()(CType* output, CFragT cFrag)
+    {
+        static constexpr uint32_t WAVE_SIZE      = 64;
+        static constexpr uint32_t VW             = 4;  // This VW is per 'chunk'
+        static constexpr uint32_t Dim            = 32; // BLOCK_N
+        static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
+
+        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
+                                           threadIdx.x % Dim);       // Col
+
+        // Minor step for each 'chunk'
+        auto
minorStepCoord2D = std::make_pair(1u, 0u);
+
+        // Major step between 'chunks'
+        auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
+
+        // Flatten to 1D row_major offsets.
+        auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
+
+        auto startOffset  = row_major(startCoord2D, 32);
+        auto kMinorOffset = row_major(minorStepCoord2D, 32);
+        auto kMajorOffset = row_major(majorStepCoord2D, 32);
+
+        output[startOffset]                                       = cFrag[0];  // v[0] = Reg 0
+        output[startOffset + kMinorOffset]                        = cFrag[1];  // v[1] = Reg 1
+        output[startOffset + 2 * kMinorOffset]                    = cFrag[2];  // v[2] = Reg 2
+        output[startOffset + 3 * kMinorOffset]                    = cFrag[3];  // v[3] = Reg 3
+        output[startOffset + kMajorOffset]                        = cFrag[4];  // v[4] = Reg 4
+        output[startOffset + kMajorOffset + kMinorOffset]         = cFrag[5];  // v[5] = Reg 5
+        output[startOffset + kMajorOffset + 2 * kMinorOffset]     = cFrag[6];  // v[6] = Reg 6
+        output[startOffset + kMajorOffset + 3 * kMinorOffset]     = cFrag[7];  // v[7] = Reg 7
+        output[startOffset + 2 * kMajorOffset]                    = cFrag[8];  // v[8] = Reg 8
+        output[startOffset + 2 * kMajorOffset + kMinorOffset]     = cFrag[9];  // v[9] = Reg 9
+        output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset] = cFrag[10]; // v[10] = Reg 10
+        output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset] = cFrag[11]; // v[11] = Reg 11
+        output[startOffset + 3 * kMajorOffset]                    = cFrag[12]; // v[12] = Reg 12
+        output[startOffset + 3 * kMajorOffset + kMinorOffset]     = cFrag[13]; // v[13] = Reg 13
+        output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset] = cFrag[14]; // v[14] = Reg 14
+        output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset] = cFrag[15]; // v[15] = Reg 15
+    }
+};
+
 template {};
     storeC(c, fragC);
 }
+
+// MX (scaled) matmul test kernel for a single BLOCK_M x BLOCK_N x BLOCK_K tile.
+//   a, b   : input matrices; xa, xb: their scale factors; c: output tile.
+// Steps: load the A and B fragments together with their scales, run one scaled MFMA
+// through mfma_type_selector, convert the accumulator to CType and store it with
+// store_C_row_major.
+// Launch requirement: one wavefront (WAVE_SIZE threads in x) per block.
+template
+__global__ void
+matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c)
+{
+    constexpr int WAVE_SIZE = 64;
+    assert(threadIdx.x < WAVE_SIZE);
+    // The fragment loaders/stores map threadIdx.x in [0, WAVE_SIZE) onto the tile, so the
+    // block must be exactly one wavefront wide (the test launches this as <<<1, 64>>>).
+    assert(blockDim.x == WAVE_SIZE && blockDim.y == 1 && blockDim.z == 1);
+
+    using AFragT        = vector_type::type;
+    using BFragT        = vector_type::type;
+    using CFragT        = vector_type::type;
+    using AccumFragT    = vector_type;
+    using RawAccumFragT = vector_type::type;
+    using ScaleFragT    = int32_t;
+
+    // Create frags
+    auto fragA   = AFragT{};
+    auto fragB   = BFragT{};
+    auto fragC   = CFragT{};
+    auto fragAcc = AccumFragT{0};
+    auto fragXa  = ScaleFragT{0};
+    auto fragXb  = ScaleFragT{0};
+
+    // Load the inputs.
+    // A = row major, BLOCK_M x BLOCK_K
+    fragA = load_mx_A_row_major(
+        a, xa, fragXa);
+
+    // B = col major, BLOCK_K x BLOCK_N
+    fragB = load_mx_B_col_major(
+        b, xb, fragXb);
+
+    // Scaled Matrix multiply-accumulate using MFMA units
+    // Accumulation intermediate = BLOCK_M x BLOCK_N
+    mfma_type_selector{}(
+        fragA, fragXa, fragB, fragXb, fragAcc);
+
+    // Convert each accumulator element to the output element type.
+    for(int i = 0; i < vectorSize(fragC); ++i)
+    {
+        fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]);
+    }
+
+    auto storeC = store_C_row_major{};
+    storeC(c, fragC);
+}
+
 /**
  * @brief Structure to hold dimension parameters for GEMM tensors.
*
@@ -373,6 +859,225 @@ struct GemmParams
     ck::index_t StrideC = -1;
 };

+namespace mxmfma_test {
+// Host reference path: instantiates ck::tensor_operation::host::ReferenceMXGemm, builds
+// its argument from A, B, their scale tensors and three PassThrough element-wise
+// operators, and runs it. C is the only tensor passed by non-const reference.
+template
+void RunHostGEMM(const Tensor& A,
+                 const Tensor& a_scales,
+                 const Tensor& B,
+                 const Tensor& b_scales,
+                 Tensor& C)
+{
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm;
+    auto ref_gemm               = ReferenceGemmInstance{};
+    auto ref_invoker            = ref_gemm.MakeInvoker();
+
+    auto ref_argument = ref_gemm.MakeArgument(
+        A, a_scales, B, b_scales, C, PassThrough{}, PassThrough{}, PassThrough{});
+
+    ref_invoker.Run(ref_argument);
+}
+
+// Device path: allocates one DeviceMem buffer per tensor (sized from the tensor's element
+// space), uploads A, B and both scale tensors, launches `kernel` on a single block of 64
+// threads and copies the C buffer back into C.mData.
+// Always returns true; no launch or copy status is inspected in this function.
+template
+bool RunDeviceGEMM(KernelType kernel,
+                   const Tensor& A,
+                   const Tensor& a_scales,
+                   const Tensor& B,
+                   const Tensor& b_scales,
+                   Tensor& C)
+{
+    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
+    DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize());
+    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
+    DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize());
+    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
+
+    a_m_k_device_buf.ToDevice(A.mData.data());
+    a_scales_device_buf.ToDevice(a_scales.mData.data());
+    b_n_k_device_buf.ToDevice(B.mData.data());
+    b_scales_device_buf.ToDevice(b_scales.mData.data());
+    kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()),
+                      static_cast(a_scales_device_buf.GetDeviceBuffer()),
+                      static_cast(b_n_k_device_buf.GetDeviceBuffer()),
+                      static_cast(b_scales_device_buf.GetDeviceBuffer()),
+                      static_cast(c_m_n_device_buf.GetDeviceBuffer()));
+    c_m_n_device_buf.FromDevice(C.mData.data());
+
+    return true;
+}
+
+// Test driver for one BLOCK_M x BLOCK_N x BLOCK_K tile: prepares the host tensors, runs
+// the host reference and the device kernel, and compares the two results.
+template
+struct TestMXMFMA
+{
+    // Creates the A/B data tensors, their scale tensors (K / BLOCK_X entries along K) and
+    // two C tensors (host result, device result), then fills the inputs according to `init`:
+    //   0       - A = 1 with scale 1/64, B = sequential values with scale 1
+    //   1       - A = B = 1 with scales 512 and 1/512
+    //   2       - GeneratorTensor_3{-2.0, 2.0} data, GeneratorTensor_2{126, 129} scales
+    //   3       - GeneratorTensor_4(0, 1) data, GeneratorTensor_2{126, 129} scales
+    //   default - GeneratorTensor_2{-5, 6} data, GeneratorTensor_2{122, 129} scales
+    // Returns a tuple (A, A scales, B, B scales, C host, C device).
+    auto PrepareGemmTensors(const GemmParams& params, index_t init)
+    {
+        // Builds a 2D descriptor with strides {stride, 1} or {1, stride}.
+        // NOTE(review): the is_same condition presumably tests for a row-major layout
+        // tag - confirm.
+        auto f_host_tensor_descriptor =
+            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                if(std::is_same::value)
+                {
+                    return HostTensorDescriptor(std::vector({row, col}),
+                                                std::vector({stride, 1}));
+                }
+                else
+                {
+                    return HostTensorDescriptor(std::vector({row, col}),
+                                                std::vector({1, stride}));
+                }
+            };
+
+        Tensor a_m_k(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor a_scales(
+            f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{}));
+        Tensor b_n_k(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor b_scales(
+            f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{}));
+        Tensor c_m_n_host_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+        Tensor c_m_n_device_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        switch(init)
+        {
+        case 0:
+            a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
+            a_scales.GenerateTensorValue(
+                GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64
+            // NOTE: not all numbers are representable in FP8, BF8, etc.
+            // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32
+            b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{});
+            b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}});
+            break;
+        case 1:
+            // results in C = {K}
+            a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
+            a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}});
+            b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f});
+            b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}});
+            break;
+        case 2:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0});
+            a_scales.GenerateTensorValue(
+                GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2}
+
+            b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0});
+            b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129});
+            break;
+
+        case 3:
+            // expect small round off errors
+            a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1));
+            a_scales.GenerateTensorValue(
+                GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2}
+            b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1));
+            b_scales.GenerateTensorValue(
+                GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2}
+            break;
+        default:
+            // all initial values are representable in FP8, BF8
+            a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5]
+            a_scales.GenerateTensorValue(
+                GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2]
+            b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5]
+            b_scales.GenerateTensorValue(
+                GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2]
+
+            break;
+        }
+
+        return std::make_tuple(
+            a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result);
+    }
+
+    // Runs the test for one initialization mode `init` using `mfma_kernel` as the device
+    // kernel. Returns the result of ck::utils::check_err(device, host) when CDataType
+    // matches one of the two types tested below, and false otherwise.
+    auto operator()(const DeviceMFMA& mfma_kernel, index_t init)
+    {
+        // Arrange
+        GemmParams params;
+        params.M = BLOCK_M;
+        params.N = BLOCK_N;
+        params.K = BLOCK_K;
+
+        auto f_get_default_stride = [](std::size_t row,
+                                       std::size_t col,
+                                       ck::index_t stride,
+                                       auto layout) {
+            if(stride == -1)
+            {
+                // stride == -1 means "not set": fall back to a default packed stride
+                // (col in the first branch, row in the second)
+                if constexpr(std::is_same_v)
+                {
+                    return static_cast(col);
+                }
+                else
+                {
+                    return static_cast(row);
+                }
+            }
+            else
+                return static_cast(stride);
+        };
+
+        params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{});
+        params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{});
+        params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{});
+
+        auto host_tensors = PrepareGemmTensors(params, init);
+
+        const Tensor& a        = std::get<0>(host_tensors);
+        const Tensor& a_scales = std::get<1>(host_tensors);
+        const Tensor& b        = std::get<2>(host_tensors);
+        const Tensor& b_scales = std::get<3>(host_tensors);
+        Tensor& c_host         = std::get<4>(host_tensors);
+        Tensor& c_device       = std::get<5>(host_tensors);
+
+        RunHostGEMM(a, a_scales, b, b_scales, c_host);
+
+        // The bool returned by RunDeviceGEMM is not used here.
+        RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);
+
+        // res stays false unless the comparison below is compiled in and succeeds.
+        bool res = false;
+        if constexpr(std::is_same::value ||
+                     std::is_same::value)
+        {
+            res = ck::utils::check_err(c_device.mData, c_host.mData);
+        }
+        else
+        {
+            std::cout << "UNSUPPORTED CDataType" << std::endl;
+        }
+
+        return res;
+    }
+};
+
+} // namespace mxmfma_test
+
 namespace mfma_test {
 template