diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 45c47672b0..9e7c360f54 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -46,7 +46,6 @@ rocm_install_targets( TARGETS ck_host ck_headers EXPORT ck_host_targets INCLUDE include - PRIVATE ) rocm_export_targets( EXPORT ck_host_targets diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..3f0b8bfe6a --- /dev/null +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -0,0 +1,87 @@ +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include +#include +#include +#include + +using half = _Float16; + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* b1, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${k}), ck::make_tuple(${m}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(${n}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${o}), ck::make_tuple(1, ${n})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${o}), ck::make_tuple(${m}, 1))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + 1.0, + a, + b, + b1, + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_batched_gemm_softmax_gemm::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + prob.O = 1024; + prob.TransB = true; + check_all check1, check2; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto b1 = to_gpu(generate_buffer(1024 * 1024, 2)); + auto c = to_gpu(generate_buffer(1024 * 1024, 3)); + + std::string epilogue = ""; + std::string prologue = ""; + + auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) + { + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)( + a.data(), b.data(), b1.data(), c.data()); + + if(solution.GetTemplateParameter("MaskOutUpperTriangle")) + CHECK(report(solution, check1(rtc::from_gpu(c)))); + else + CHECK(report(solution, check2(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index 9e2d990d9b..2a383fc1c8 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -6,134 +6,24 @@ #include "ck/host/headers.hpp" #include "ck/host/stringutils.hpp" #include "ck/host/utils.hpp" -#include -#include -#include -#include -#include +#include "common.hpp" #include #include +#include +#include +#include #include +#include +#include using half = _Float16; -// using half = __fp16; - -std::vector get_headers_for_test() -{ - std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); - return result; -} - -template -rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) -{ - rtc::buffer result(n); - std::mt19937 gen(seed); - std::uniform_real_distribution dis(-1.0); - std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - return result; -} - -template -bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) -{ - return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { - return fabs(x - y) < atol + rtol * fabs(y); - }); -} - -std::string classify(double x) -{ - switch(std::fpclassify(x)) - { - case FP_INFINITE: return "inf"; - case FP_NAN: return "nan"; - case FP_NORMAL: return "normal"; - case FP_SUBNORMAL: return "subnormal"; - case FP_ZERO: return "zero"; - default: return "unknown"; - } -} - -template -void print_classification(const Buffer& x) -{ - std::unordered_set result; - for(const auto& i : x) - result.insert(classify(i)); - for(const auto& c : result) - std::cout << c << ", "; - std::cout << std::endl; -} - -template -void print_statistics(const Buffer& x) -{ - std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; - std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; - double num_elements = x.size(); - auto mean = - std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; - auto stddev = std::sqrt( - std::accumulate(x.begin(), - x.end(), - double{0.0}, - [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / - num_elements); - std::cout << "Mean: " << mean << ", "; - std::cout << "StdDev: " << stddev << "\n"; -} - -template -void print_preview(const Buffer& x) -{ - if(x.size() <= 10) - { - std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); - } - else - { - std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); - std::cout << "..., "; - std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); - } - std::cout << std::endl; -} - -template -struct check_all -{ - rtc::buffer data{}; - bool operator()(const rtc::buffer& x) - { - if(data.empty()) - { - data = x; - return true; - } - if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); })) - return false; - return allclose(data, x); - } -}; - -template -auto report(const Solution& solution, bool pass) -{ - return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); -} const std::string gemm_compile_check = R"__ck__( #include <${include}> extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { using G = ${template}; - constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), ck::make_tuple(), ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); @@ -166,15 +56,19 @@ TEST_CASE(test_problem_kernel) std::string epilogue = ""; std::string prologue = ""; - for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) + auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) { - auto src = ck::host::InterpolateString(gemm_compile_check, + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}, {"m", std::to_string(prob.M)}, {"n", std::to_string(prob.N)}, {"k", std::to_string(prob.K)}}); - auto srcs = get_headers_for_test(); + auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; options.kernel_name = "f"; diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 24fde2e523..b3be592e74 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -2,27 +2,38 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once + +#include "ck/host/headers.hpp" +#include +#include +#include #include #include #include #include #include -#include -#include -#include -#include +#include -std::vector get_headers_for_test() +inline std::vector create_headers_for_test() { + auto ck_headers = ck::host::GetHeaders(); std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); + std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) { + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; + }); return result; } +inline const std::vector& get_headers_for_test() +{ + static const std::vector headers = create_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { @@ -37,18 +48,24 @@ std::size_t GetSize(V mLens, V mStrides) return space; } -template -rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) { - std::size_t space = GetSize(mLens, mStrides); - rtc::buffer result(space); + rtc::buffer result(n); std::mt19937 gen(seed); std::uniform_real_distribution dis(-1.0); std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - // std::fill(result.begin(), result.end(), 1); return result; } +template +std::enable_if_t, rtc::buffer> +generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + return generate_buffer(space, seed); +} + template bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) { @@ -57,7 +74,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) }); } -std::string classify(double x) +inline std::string classify(double x) { switch(std::fpclassify(x)) { diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 68bfc2467b..2e7ceb5648 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -4,3 +4,9 @@ add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) target_link_libraries(ck_rtc PUBLIC -lstdc++fs) + +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) + message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") +endif() diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index a49714f7c6..207f10a8e8 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -12,8 +12,9 @@ namespace rtc { struct src_file { + src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {} fs::path path; - std::string_view content; + std::string content; }; struct compile_options @@ -22,7 +23,7 @@ struct compile_options std::string kernel_name = "main"; }; -kernel compile_kernel(const std::vector& src, +kernel compile_kernel(const std::vector& srcs, compile_options options = compile_options{}); } // namespace rtc diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 5a70f898e8..a8da88be09 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -3,14 +3,41 @@ #include #include +#ifdef HIPRTC_FOR_CODEGEN_TESTS +#include +#include +#endif #include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include namespace rtc { +bool EndsWith(const std::string& value, const std::string& suffix) +{ + if(suffix.size() > value.size()) + return false; + else + return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin()); +} + +std::vector SplitString(const std::string& s, char delim) +{ + std::vector elems; + std::stringstream ss(s + delim); + std::string item; + while(std::getline(ss, item, delim)) + { + elems.push_back(item); + } + return elems; +} + template T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) { @@ -62,7 +89,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device // TODO: undo after extracting the codeobj // std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } -kernel compile_kernel(const std::vector& srcs, compile_options options) +kernel clang_compile_kernel(const std::vector& srcs, compile_options options) { assert(not srcs.empty()); tmp_dir td{"compile"}; @@ -103,4 +130,172 @@ kernel compile_kernel(const std::vector& srcs, compile_options options return kernel{obj.data(), options.kernel_name}; } +#ifdef HIPRTC_FOR_CODEGEN_TESTS + +std::string hiprtc_error(hiprtcResult err, const std::string& msg) +{ + return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); +} + +void hiprtc_check_error(hiprtcResult err, const std::string& msg = "") +{ + if(err != HIPRTC_SUCCESS) + throw std::runtime_error(hiprtc_error(err, msg)); +} + +struct hiprtc_src_file +{ + hiprtc_src_file() = default; + hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {} + std::string path; + std::string content; +}; + +void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } +using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); + +template +hiprtc_program_ptr hiprtc_program_create(Ts... xs) +{ + hiprtcProgram prog = nullptr; + auto result = hiprtcCreateProgram(&prog, xs...); + hiprtc_program_ptr p{prog}; + hiprtc_check_error(result, "Create program failed."); + return p; +} + +struct hiprtc_program +{ + struct string_array + { + std::deque strings{}; + std::vector c_strs{}; + + string_array() {} + string_array(const string_array&) = delete; + + std::size_t size() const { return strings.size(); } + + const char** data() { return c_strs.data(); } + + void push_back(std::string s) + { + strings.push_back(std::move(s)); + c_strs.push_back(strings.back().c_str()); + } + }; + + hiprtc_program_ptr prog = nullptr; + string_array headers{}; + string_array include_names{}; + std::string cpp_src = ""; + std::string cpp_name = ""; + + hiprtc_program(const std::string& src, const std::string& name = "main.cpp") + : cpp_src(src), cpp_name(name) + { + create_program(); + } + + hiprtc_program(std::vector srcs) + { + for(auto&& src : srcs) + { + if(EndsWith(src.path, ".cpp")) + { + cpp_src = std::move(src.content); + cpp_name = std::move(src.path); + } + else + { + headers.push_back(std::move(src.content)); + include_names.push_back(std::move(src.path)); + } + } + create_program(); + } + + void create_program() + { + assert(not cpp_src.empty()); + assert(not cpp_name.empty()); + assert(headers.size() == include_names.size()); + prog = hiprtc_program_create(cpp_src.c_str(), + cpp_name.c_str(), + headers.size(), + headers.data(), + include_names.data()); + } + + void compile(const std::vector& options, bool quiet = false) const + { + std::vector c_options; + std::transform(options.begin(), + options.end(), + std::back_inserter(c_options), + [](const std::string& s) { return s.c_str(); }); + auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); + auto prog_log = log(); + if(not prog_log.empty() and not quiet) + { + std::cerr << prog_log << std::endl; + } + if(result != HIPRTC_SUCCESS) + throw std::runtime_error("Compilation failed."); + } + + std::string log() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n)); + if(n == 0) + return {}; + std::string buffer(n, '\0'); + hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data())); + assert(buffer.back() != 0); + return buffer; + } + + std::vector get_code_obj() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n)); + std::vector buffer(n); + hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data())); + return buffer; + } +}; + +std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, + const compile_options& options) +{ + hiprtc_program prog(srcs); + auto flags = SplitString(options.flags, ' '); + prog.compile(flags); + return {prog.get_code_obj()}; +} + +static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_options options) +{ + options.flags += " -I. -O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + auto cos = compile_hip_src_with_hiprtc(srcs, options); + if(cos.size() != 1) + std::runtime_error("No code object"); + auto& obj = cos.front(); + return kernel{obj.data(), options.kernel_name}; +} + +#endif + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ +#ifdef HIPRTC_FOR_CODEGEN_TESTS + return hiprtc_compile_kernel(srcs, options); +#else + return clang_compile_kernel(srcs, options); +#endif +} + } // namespace rtc diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index e9df8c9f5f..d61b5e2b27 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.15.0 +rocm-docs-core==1.17.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a42fdf09bf..177f3ec184 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -199,7 +199,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.15.0 +rocm-docs-core==1.17.0 # via -r requirements.in rpds-py==0.22.3 # via diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp index dbdf7199e8..01542c4775 100755 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -27,11 +27,15 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_AMD_MFMA_GFX950) + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +#else // defined(CK_USE_AMD_MFMA_GFX950) < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>; +#endif // defined(CK_USE_AMD_MFMA_GFX950) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 497ea19e11..6fbaee7dba 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -13,3 +13,9 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) + +add_example_executable(example_grouped_conv_bwd_weight_v3_xdl_bf16 grouped_conv_bwd_weight_v3_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_xdl_bf16) + +add_example_executable(example_grouped_conv_bwd_weight_v3_xdl_fp16 grouped_conv_bwd_weight_v3_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_xdl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp new file mode 100644 index 0000000000..c0168bf117 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +using InDataType = BF16; +// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory +using WeiDataType = F32; +using OutDataType = BF16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + // clang-format on + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 32, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + // clang-format off + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp new file mode 100644 index 0000000000..4b9fd356df --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 32, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + false, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + false, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 55c884246f..2d00545515 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,4 +1,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) -add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) \ No newline at end of file +add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index cb4f60764e..c33fe357d8 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -69,18 +69,21 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR - ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; -///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -229,7 +232,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50, true, 50}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..9a81ef5ea7 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -0,0 +1,396 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using B0DataType = FP8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 256, 256, 128, + 16, 16, + 16, 16, + 8, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + ck::index_t Warmup = 50; + ck::index_t Repeat = 50; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else if(argc == 14) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + + Warmup = std::stoi(argv[12]); + Repeat = std::stoi(argv[13]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + printf("arg10 to 11: Warmup, Repeat\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + size_t total_size = + (M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) + + N * sizeof(D1DataType) + M * N * sizeof(EDataType)); + int rotate_buf_num = + ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size)); + + float ave_time = invoker.Run( + argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp index fb1642bba5..cbbd37408e 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -36,9 +36,9 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using A0DataType = I8; using B0DataType = I8; using AccDataType = I32; -using CShuffleDataType = I32; -using D0DataType = F32; -using D1DataType = F32; +using CShuffleDataType = F16; +using D0DataType = F16; +using D1DataType = F16; using DsDataType = ck::Tuple; using EDataType = F16; @@ -74,6 +74,24 @@ struct MultiplyMultiply e = ck::type_convert(x0_f); } + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x0_f = ck::type_convert(c) * d0 * d1; + + e = x0_f; + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x0_f = c * d0 * d1; + + e = x0_f; + } + template <> __host__ __device__ constexpr void operator()( ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const @@ -91,7 +109,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off @@ -102,7 +120,17 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ///###### RRR ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; ///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; + < Row, Col, DsLayout, ELayout, + A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 64, 128, 256, + 16, 16, + 32, 32, + 1, 2, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>; // clang-format on int main(int argc, char* argv[]) @@ -196,6 +224,12 @@ int main(int argc, char* argv[]) d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-25, 25}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 25}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 200}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 200}); + break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -251,7 +285,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + hipStream_t stream; + hip_check_error(hipStreamCreate(&stream)); + + float ave_time = invoker.Run(argument, StreamConfig{stream, time_kernel, 0, 20, 50, true, 50}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 5b00b5a123..30f03cb53b 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -13,7 +13,9 @@ #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" + #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index c05660c8ab..4c23250d05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -492,6 +492,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] @@ -499,14 +500,15 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue - if receipt == 3: + elif receipt == 3: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue - if receipt == 4: + # PyTorch integration + elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'bias'] cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] @@ -514,6 +516,26 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + # Aiter (mha_bwd) integration + elif receipt == 10: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "t" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "t" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index ad8daba17e..b72627ed5d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -487,6 +487,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -494,13 +495,32 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue - if receipt == 4: + # PyTorch integration + elif receipt == 4: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_fwd) integration + elif receipt == 10: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 2f20819302..f8a89448ba 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -326,7 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + # 2 - Flash attention integration + # 12 - Aiter(mha_fwd_kvcache) integration + if receipt in (2, 12): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 37745dd382..c0ca666b11 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -268,7 +268,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - + // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; @@ -705,6 +705,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -712,6 +713,24 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 11: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd_kvcache) integration + elif receipt == 12: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index a0fb42aa11..0c2cef1ce7 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -17,7 +17,7 @@ class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 -# inspect all modules under 'codegen.ops' and register API handlers +# inspect all modules under 'codegen.ops' and register API handlers ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) @@ -104,7 +104,11 @@ if __name__ == "__main__": help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ " 2: Only generate instance for Flash attention integration\n" + \ - " 4: Only generate instance for PyTorch integration" + " 4: Only generate instance for PyTorch integration\n" + \ + " 10: Only generate instance for Aiter(mha_fwd, mha_bwd) integration\n" + \ + " 11: Only generate instance for Aiter(mha_varlen_fwd, mha_varlen_bwd) integration\n" + \ + " 12: Only generate instance for Aiter(mha_fwd_kvcache) integration" + ) args = parser.parse_args() diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 636b34981f..dbc582e5a3 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -35,7 +35,7 @@ #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif -template +template struct GemmBasicTypeConfig; template <> @@ -75,6 +75,15 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -114,6 +123,12 @@ struct DataTypeTraits static constexpr const char* name = "bf8"; }; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 042ad372dc..f068cbc1da 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -29,6 +29,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K, // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } +template +void permute_tensor_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} template +template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -94,10 +153,7 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -107,10 +163,10 @@ int run_gemm_example_with_layouts(int argc, ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - ck_tile::index_t kbatch = arg_parser.get_int("split_k"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - ck_tile::index_t init_method = arg_parser.get_int("init"); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); @@ -123,16 +179,23 @@ int run_gemm_example_with_layouts(int argc, ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - if (init_method == 0) { + if(init_method == 0) + { ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); - } else if (init_method == 1) { + } + else if(init_method == 1) + { ck_tile::FillMonotonicSeq{}(a_m_k); ck_tile::FillMonotonicSeq{}(b_k_n); - } else if (init_method == 2) { + } + else if(init_method == 2) + { ck_tile::FillConstant{static_cast(1)}(a_m_k); ck_tile::FillConstant{static_cast(1)}(b_k_n); - } else { + } + else + { a_m_k.SetZero(); b_k_n.SetZero(); } @@ -142,7 +205,17 @@ int run_gemm_example_with_layouts(int argc, ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); + if constexpr(std::is_same_v) + { + // Permute data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + permute_tensor_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); @@ -188,6 +261,11 @@ int run_gemm_example_with_layouts(int argc, } else if(arg_parser.get_int("v") == 2) { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } ck_tile::HostTensor c_m_n_gpu_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); @@ -198,17 +276,18 @@ int run_gemm_example_with_layouts(int argc, BDataType* d_B; CDataType* d_C; - ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); - ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); - ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); + ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); + ck_tile::hip_check_error( + hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); ck_tile::hip_check_error(hipMemcpy(d_A, a_m_k_dev_buf.GetDeviceBuffer(), - M * K * sizeof(ADataType), + a_m_k.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::hip_check_error(hipMemcpy(d_B, b_k_n_dev_buf.GetDeviceBuffer(), - N * K * sizeof(BDataType), + b_k_n.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::reference_gemm_gpu(argc, argv, Row{}, Col{}, Row{}); } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } +#endif else { throw std::runtime_error("Unsupported data_type!"); @@ -344,6 +353,15 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } +#endif else { throw std::runtime_error("Unsupported data_type!"); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 43bc9a6cfe..574edf64d3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -3,7 +3,7 @@ #include // different threshold for different dtype -template +template auto get_elimit() { double rtol = 1e-2; @@ -39,6 +39,7 @@ auto create_args(int argc, char* argv[]) .insert("v", "1", "cpu validation or not") .insert("kname", "1", "print kernel name or not") .insert("prec", "fp16", "precision") + .insert("quant", "int8", "precision") .insert("warmup", "5", "cold iter") .insert("repeat", "20", "hot iter"); @@ -46,7 +47,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -template +template bool run(const ck_tile::ArgParser& arg_parser) { ck_tile::index_t m = arg_parser.get_int("m"); @@ -54,16 +55,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t stride = arg_parser.get_int("stride"); if(stride < 0) stride = n; - float epsilon = arg_parser.get_float("e"); - std::string data_type = arg_parser.get_str("prec"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + float epsilon = arg_parser.get_float("e"); + std::string input_data_type = arg_parser.get_str("prec"); + std::string quantized_data_type = arg_parser.get_str("quant"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); assert(stride >= n); - using TypeConfig = AddRmsnormRdquantTypeConfig; + using TypeConfig = AddRmsnormRdquantTypeConfig; using ADataType = typename TypeConfig::ADataType; using BDataType = typename TypeConfig::BDataType; @@ -102,10 +104,10 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << data_type << "]" + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; - add_rmsnorm2d_rdquant_fwd_traits traits{data_type, SaveX}; + add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), b_buf.GetDeviceBuffer(), @@ -129,14 +131,14 @@ bool run(const ck_tile::ArgParser& arg_parser) num_byte += sizeof(XDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl; bool pass = true; if(do_validation) { using YDataType = ComputeDataType; - using InvRmsDataType = DataType; + using InvRmsDataType = InputDataType; // Add { @@ -144,28 +146,36 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_binary_elementwise( a_host, b_host, x_host_ref, op); - x_buf.FromDevice(x_host_dev.data()); + if constexpr(SaveX) + { + x_buf.FromDevice(x_host_dev.data()); - auto [rtol, atol] = get_elimit(); - if(stride == n) - { - pass = ck_tile::check_err( - x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); - } - else - { - for(int i_r = 0; i_r < m; i_r++) + auto [rtol, atol] = get_elimit(); + if(stride == n) { - std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, - x_host_dev.begin() + i_r * stride + n); - std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, - x_host_ref.begin() + i_r * stride + n); - pass &= ck_tile::check_err(x_host_dev_row, - x_host_ref_row, - std::string("x[") + std::to_string(i_r) + - std::string("] Error: Incorrect results!"), - rtol, - atol); + pass = ck_tile::check_err(x_host_dev, + x_host_ref, + std::string("x Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } } } } @@ -256,23 +266,40 @@ int main(int argc, char* argv[]) if(!result) return -1; - const std::string data_type = arg_parser.get_str("prec"); - int save_x = arg_parser.get_int("save_x"); - if(data_type == "fp16" && save_x) + const std::string input_data_type = arg_parser.get_str("prec"); + const std::string quantized_data_type = arg_parser.get_str("quant"); + int save_x = arg_parser.get_int("save_x"); + if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "fp16" && !save_x) + else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && save_x) + else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } - else if(data_type == "bf16" && !save_x) + else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x) { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index 443b9b1024..c91b387d62 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -8,11 +8,11 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" #include -template +template struct AddRmsnormRdquantTypeConfig; template <> -struct AddRmsnormRdquantTypeConfig +struct AddRmsnormRdquantTypeConfig { using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; @@ -24,7 +24,7 @@ struct AddRmsnormRdquantTypeConfig }; template <> -struct AddRmsnormRdquantTypeConfig +struct AddRmsnormRdquantTypeConfig { using ADataType = ck_tile::bf16_t; using BDataType = ck_tile::bf16_t; @@ -35,13 +35,38 @@ struct AddRmsnormRdquantTypeConfig using ComputeDataType = float; }; +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + // runtime args struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel -template struct add_rmsnorm2d_rdquant_fwd_traits_ { - using DataType = ck_tile::remove_cvref_t; + using InputDataType = ck_tile::remove_cvref_t; + using QuantizedDataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); @@ -114,7 +140,8 @@ float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_ // This is the public API, will be generated by script struct add_rmsnorm2d_rdquant_fwd_traits { - std::string data_type; + std::string input_data_type; + std::string quantized_data_type; bool save_x; }; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp index 966c5bd02f..f695ea30b2 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -4,7 +4,8 @@ #include #include "add_rmsnorm2d_rdquant_fwd.hpp" -template -using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; -template -float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t, add_rmsnorm2d_rdquant_fwd_args a, const ck_tile::stream_config& s) { @@ -32,99 +34,145 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, // clang-format off // rm rn tm tn vn pd x 3p if(a.n <= 64) { - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 128) { if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 256) { if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 512) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 768) { if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 1024) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 1536) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 2048) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 3072) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } else if(a.n <= 4096) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } - else if(a.n > 4096) { + else if(a.n <= 8192) { + if(a.n<8192){ + if(t.save_x){ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + else{ + if(t.save_x){ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + } + else if(a.n > 8192) { if (a.n % 8 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 4 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else if (a.n % 2 == 0) - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); else - r = add_rmsnorm2d_rdquant_fwd_>(s, a); + r = add_rmsnorm2d_rdquant_fwd_>(s, a); } return r; // clang-format on @@ -134,16 +182,45 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, add_rmsnorm2d_rdquant_fwd_args a, const ck_tile::stream_config& s) { - - // Only support instance of save_x == true for now - assert(t.save_x); - if(t.data_type.compare("fp16") == 0) + if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) { - return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); } - else if(t.data_type.compare("bf16") == 0) + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) { - return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); } else throw std::runtime_error("Without supported instances!"); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp index 5495e3c9ab..00df2f5082 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp @@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); #endif -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp index 8bbfdc8589..2adb54c078 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp index 381a11fc80..39089843a2 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp index 2fefac6934..ddb8e1b354 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp index 263713bbc7..2a87614403 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp index c62c596fab..045a3b8880 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp deleted file mode 100644 index e4951f6ab9..0000000000 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp index 4c7ee48e8e..1028973e74 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp index 8659dc82b3..b8439a0ce9 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp index 5f15f11b47..b24b245757 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp new file mode 100644 index 0000000000..14f0ec8525 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..3e3a6d75b9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp index 8ffdacbdcd..04d735c12c 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -15,8 +15,12 @@ template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); #endif -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp index 3551099651..5893d6c3ee 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp index d4d0474c27..ec9c417bf3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -6,9 +6,13 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp index 2cb300eda6..5bc8245106 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp index fb0ceb4c58..c022c62de6 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp index 3a241a3c93..19172b0793 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -6,9 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp deleted file mode 100644 index d3094679f9..0000000000 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp +++ /dev/null @@ -1,14 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" - -// clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); - -// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp index 919bc177e8..f491d92787 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -6,8 +6,12 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp index 8a44f5e00f..065f0ea4cc 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -5,8 +5,11 @@ #include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" // clang-format off -// rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp index 5c4f05ec3c..be8c6c4de5 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -6,7 +6,10 @@ // clang-format off // rm rn tm tn vn pd x 3p -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); -template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); // clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp new file mode 100644 index 0000000000..ad2dfd931e --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp @@ -0,0 +1,41 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..e3afa07fa4 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index 6baaad471a..25b10e1dc4 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -11,7 +11,8 @@ using S = ck_tile::stream_config; using A = add_rmsnorm2d_rdquant_fwd_args; -template -using trait_ = add_rmsnorm2d_rdquant_fwd_traits_ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) { - using DataType = typename Traits_::DataType; + using InputDataType = typename Traits_::InputDataType; + using QuantizedDataType = typename Traits_::QuantizedDataType; using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< - typename AddRmsnormRdquantTypeConfig::ADataType, - typename AddRmsnormRdquantTypeConfig::BDataType, - typename AddRmsnormRdquantTypeConfig::GammaDataType, - typename AddRmsnormRdquantTypeConfig::ComputeDataType, - typename AddRmsnormRdquantTypeConfig::XDataType, - typename AddRmsnormRdquantTypeConfig::YScaleDataType, - typename AddRmsnormRdquantTypeConfig::QYDataType, + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, typename Traits_::Shape, Traits_::kPadN, Traits_::kSaveX, diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 1ec0c6bc23..c8d1c20f4c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,8 +4,9 @@ #pragma once #include "ck/config.h" + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include "ck/utility/env.hpp" -#ifndef CK_CODE_GEN_RTC #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 05dc491af7..e04e27b761 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -3,6 +3,7 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include #include @@ -97,3 +98,4 @@ inline bool is_gfx12_supported() } } // namespace ck +#endif diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index c0894f1d70..0dfd275269 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -19,16 +19,15 @@ inline void hip_check_error(hipError_t x) } } -#define HIP_CHECK_ERROR(retval_or_funcall) \ - do \ - { \ - hipError_t _tmpVal = retval_or_funcall; \ - if(_tmpVal != hipSuccess) \ - { \ - std::ostringstream ostr; \ - ostr << "HIP Function Failed (" \ - << "hip_check_error.hpp" \ - << "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \ - throw std::runtime_error(ostr.str()); \ - } \ +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ + << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ } while(0) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 962f89e479..5c1c1c4e60 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef __HIPCC_RTC__ #include #include "ck/ck.hpp" @@ -166,3 +166,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, return 0; #endif } +#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp index 1ab460fa8a..5997b47790 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp @@ -9,13 +9,6 @@ namespace ck { -enum struct BlockGemmPipelineVersion -{ - v1, // Naive - v2, // Mem - v3, // Comp -}; - template +constexpr auto BlockGemmBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp new file mode 100644 index 0000000000..8ed25895b5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp new file mode 100644 index 0000000000..4c019a41a4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 2 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + A local + static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A + }); + + static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + // ignore = i; + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + // }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1, B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_bufs(I0)); + }); + }); + + // Local prefill A2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); + + // // Global prefetch A3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_bufs(local_read_buf)); + }); + }); + + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 3)); + } + // tail + + auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_reg)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_bufs(local_read_reg)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_reg)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_bufs(local_read_reg)); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + auto CompFunc = [&](auto mfma_reg) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I0, I1); + CompFunc(I1); + } + else if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0, I1); + ReadCompFunc(I1, I0); + CompFunc(I0); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp new file mode 100644 index 0000000000..49af782132 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -0,0 +1,860 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + __device__ static constexpr auto HotLoopScheduler(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 2) + { + constexpr auto staged_num_mfma_per_buffer_load_a = + math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + } + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + template + __device__ static constexpr auto EpilogueScheduler_1(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { +#if 0 + constexpr auto staged_num_ds_write_a_per_ds_read_a = + num_ds_write_inst_a / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; + // A local write + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { + ignore = idswrite_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + }); + + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); +#elif 1 + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); +#endif + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + __device__ static constexpr auto EpilogueScheduler_2() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(I0, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(I0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + if constexpr(m0.value == 0) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + } + else if constexpr(m0.value == 1) + { + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + } + else if constexpr(m0.value == 2) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 1) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + I0), + a_thread_buf); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + I0), + a_thread_buf); + }); + } + + HotLoopScheduler(m0); + }); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + if constexpr(m0.value == 0) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + } + else if constexpr(m0.value == MRepeat - 1) + { + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 1) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + } + + EpilogueScheduler_1(m0); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value != (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, I0, I0, k0, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + + EpilogueScheduler_2(); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value != (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + + EpilogueScheduler_2(); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp index ea0c511da3..48fe5131d2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp @@ -11,15 +11,6 @@ namespace ck { -enum struct BlockGemmPipelineVersion -{ - v1, // Naive - v2, // Mem - v3, // Comp - v4, // Comp, double lds buffer - v5, // Comp, double global prefetch register buffer -}; - template - __device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group) + __device__ static constexpr void HotLoopScheduler() { // TODO: Take data type into consideration as pipe ver 3 // A-B splited schedule @@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto idsread) { ignore = idsread; - __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue - num_dsread_per_issue_a - num_dswrite_per_issue_a, - schedule_group); // MFMA + 0); // MFMA }); static_for<0, num_issue_b, 1>{}([&](auto i) { ignore = i; static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { ignore = idsread; - __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { ignore = idswrite; - __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); - __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_issue - num_dsread_per_issue_a - num_dswrite_per_issue_b, - schedule_group); // MFMA + 0); // MFMA }); __builtin_amdgcn_sched_barrier(0); } @@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4{}> b_thread_bufs; // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Global prefetch 2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0); - - // Local prefill 2 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); // Local prefetch 1 block_sync_lds(); @@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k) { @@ -367,13 +357,11 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k) { @@ -447,8 +433,8 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -478,13 +464,10 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k) { @@ -534,7 +517,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 + __device__ constexpr auto GetSrcThreadScratchIdx() + { + return threadwise_transfer_.template GetSrcThreadScratchIdx(); + } + template __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 774982d905..9285211519 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -3,11 +3,12 @@ #pragma once -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #include #include + #include "ck/stream_config.hpp" #endif @@ -15,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #define GET_OBJECT_NAME_IMLP \ std::optional GetObjectName() const override \ { \ @@ -77,7 +78,7 @@ struct BaseOperator BaseOperator() = default; BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index 09259224e7..204b09cad4 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -2,9 +2,10 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef __HIPCC_RTC__ #include #include +#endif #include "device_base.hpp" @@ -28,6 +29,7 @@ template // TODO: enum for mask type struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator { +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 48fca67f56..3c79b92ec8 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -2,9 +2,11 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#ifndef __HIPCC_RTC__ #include +#endif +#include "ck/utility/array.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator CDEElementwiseOperation cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif }; // GEMM: @@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); +#ifndef __HIPCC_RTC__ virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -94,6 +99,52 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator CDEElementwiseOperation cde_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index 997dcb75a6..8824f44ec5 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -28,8 +28,7 @@ enum struct GemmSpecialization NKOPadding, MNKOPadding, }; - -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::string getGemmSpecializationString(const GemmSpecialization& s) { switch(s) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index ea5a5d0e16..b4ab96d397 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -3,8 +3,12 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -15,8 +19,6 @@ #include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { @@ -429,6 +431,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle matrix_padder.PadN, MaskOutUpperTriangle>; +#ifndef __HIPCC_RTC__ // Argument struct Argument : public BaseArgument { @@ -603,6 +606,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return Run(*dynamic_cast(p_arg), stream_config); } }; +#endif static constexpr bool IsValidCompilationParameter() { @@ -610,6 +614,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } +#ifndef __HIPCC_RTC__ static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) { @@ -837,6 +842,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } +#endif template struct Descriptor diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index e6466a487b..3fae3a3765 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -3,8 +3,12 @@ #pragma once +#ifndef __HIPCC_RTC__ #include #include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -14,8 +18,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { @@ -224,9 +226,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD& MRaws, - const std::array& NRaws, - const std::array& DsStride) + static auto MakeDsGridDescriptor_M_N(const Array& MRaws, + const Array& NRaws, + const Array& DsStride) { return generate_tuple( [&](auto i) { @@ -308,6 +310,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; +#ifndef __HIPCC_RTC__ // Argument struct Argument : public BaseArgument { @@ -497,6 +500,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD LoopSchedToString{ - {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + std::map LoopSchedToString{{LoopScheduler::Default, "Default"}, + { LoopScheduler::Interwave, + "Interwave" }}; std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, - {PipelineVersion::v2, "v2"}}; + { PipelineVersion::v2, + "v2" }}; // clang-format off str << "DeviceGemmMultipleD_Xdl_CShuffle" @@ -708,6 +716,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD struct Descriptor @@ -846,7 +855,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 7661114ef8..0796614bd4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -227,8 +227,20 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleDSplitKBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = + NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2); + constexpr auto estimated_reg_c = + MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4; + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right + // now"); + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + throw std::runtime_error("todo: only v1 v2 and v3 support now"); + } + } +#if 0 + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } +#endif + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + throw std::runtime_error("todo: only v3 support now"); + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA, + index_t MaxTransposeTransferSrcScalarPerVector = 1, + index_t MaxTransposeTransferDstScalarPerVector = 1> struct DeviceGroupedConvBwdWeight_Xdl_CShuffle : public DeviceGroupedConvBwdWeight() || + is_NGCDHW_GKZYXC_NGKDHW()) || + is_same_v); + using AElementwiseOperation = OutElementwiseOperation; using BElementwiseOperation = InElementwiseOperation; using CElementwiseOperation = WeiElementwiseOperation; @@ -279,6 +288,51 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + static constexpr index_t TransposeTransferSrcScalarPerVectorAligned = + std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector); + static constexpr index_t TransposeTransferDstScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferDstScalarPerVector); + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + using GridwiseElementwiseTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, @@ -398,6 +452,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -407,9 +468,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle input_spatial_lengths_, filter_spatial_lengths_, output_spatial_lengths_, - b_g_n_c_wis_strides, + b_g_n_c_wis_strides_transposed, e_g_k_c_xs_strides, - a_g_n_k_wos_strides, + a_g_n_k_wos_strides_transposed, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -424,8 +485,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = Conv_K_ * Conv_C_ * std::accumulate(begin(filter_spatial_lengths_), @@ -441,6 +502,54 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); } + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSizeBytes() const + { + // Transpose require workspace for A and B + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes(); + } + else + { + return 0; + } } const ADataType* p_a_grid_; @@ -453,6 +562,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + // for computing batch offset ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; @@ -502,13 +617,57 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + float avg_time = 0.f; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + ADataType* p_out_a_grid = type_convert(arg.p_workspace_); + BDataType* p_out_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_out_a_grid), + make_tuple(p_out_b_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); } const index_t grid_size = @@ -536,33 +695,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ComputePtrOffsetOfStridedBatch<>, has_main_loop>; - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.Conv_G_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + avg_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); }; if(has_main_k0_block_loop) { - return launch_kernel(integral_constant{}); + launch_kernel(integral_constant{}); } else { - return launch_kernel(integral_constant{}); + launch_kernel(integral_constant{}); } + return avg_time; } float Run(const BaseArgument* p_arg, @@ -598,7 +759,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle else if constexpr(NDimSpatial == 2) { if constexpr(!(is_NHWGC_GKYXC_NHWGK() || - is_GNHWC_GKYXC_GNHWK())) + is_GNHWC_GKYXC_GNHWK() || + is_NGCHW_GKYXC_NGKHW())) { return false; } @@ -606,7 +768,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle else if constexpr(NDimSpatial == 3) { if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || - is_GNDHWC_GKZYXC_GNDHWK())) + is_GNDHWC_GKZYXC_GNDHWK() || + is_NGCDHW_GKZYXC_NGKDHW())) { return false; } @@ -644,6 +807,35 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return false; } + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0) + { + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVectorAligned != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0) + { + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0) + { + return false; + } + } + // Gridwise GEMM size return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, @@ -764,12 +956,49 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle << BBlockTransferDstScalarPerVector_K1 << ", " << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", " - << CBlockTransferScalarPerVector_NWaveNPerXdl - << ">"; + << CBlockTransferScalarPerVector_NWaveNPerXdl; + + if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) { + str << ", TransposeTransferSrcScalarPerVectorAligned: " + << TransposeTransferSrcScalarPerVectorAligned <<", " + << "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned; + } + + + str << ">"; // clang-format on return str.str(); } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!"); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..f6ea23a1e7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,1349 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static inline auto I0 = Number<0>{}; + static inline auto I1 = Number<1>{}; + static inline auto I2 = Number<2>{}; + static inline auto I3 = Number<3>{}; + static inline auto I4 = Number<4>{}; + static inline auto I5 = Number<5>{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemmV2{}; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = + GridwiseGemm_xdl_cshuffle_v3; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, arg.p_c_grid_, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * K0PerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * K0PerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.k_batch_ > 1) + { + return false; + } + + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 0ec55984bc..9fe2f0d976 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -13,6 +13,7 @@ enum struct MaskingSpecialization MaskOutUpperTriangle }; +#ifndef __HIPCC_RTC__ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) { switch(s) @@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s default: return "Unrecognized specialization!"; } } +#endif struct MaskDisabledPredicate { @@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate template struct C0MatrixMask_impl { - __host__ __device__ C0MatrixMask_impl(index_t NRaw) + __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) { } diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index 85adb64b43..e836e73a1d 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -21,6 +21,12 @@ struct ColumnMajor : public BaseTensorLayout { static constexpr const char* name = "ColumnMajor"; }; + +struct MFMA : public BaseTensorLayout +{ + static constexpr const char* name = "MFMA"; +}; + } // namespace gemm namespace convolution { @@ -430,7 +436,7 @@ struct G_NDHW : public BaseTensorLayout } // namespace convolution -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template < typename Layout, typename std::enable_if::value, bool>::type = false> diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 370d03258d..b57ae22172 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -283,6 +283,16 @@ struct MultiplyMultiply e = ck::type_convert(x0_f); } + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + template <> __host__ __device__ constexpr void operator()( ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index f1055d1eff..f1d0f9844d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -284,6 +284,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(half_t& y, const int32_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { @@ -691,7 +697,7 @@ struct FastGelu template __device__ void operator()(Y& y, const X& x) const; -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template <> __host__ void operator()(float& y, const float& x) const { @@ -703,7 +709,6 @@ struct FastGelu y = x / (1.f + emu); } #endif - // device code, use lower precision "__ocml_exp_f32" and "rcp" template <> __device__ void operator()(float& y, const float& x) const diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 2bc9ef87ac..64fad1ca48 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -8,7 +8,7 @@ #include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index eb1eb533d7..060f6d5d15 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -473,7 +473,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template __host__ __device__ static auto MakeDsGridDescriptor_M_N(const ck::Array& MRaws, @@ -486,6 +486,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const std::array& NRaws, const std::array& DsStride) #endif + { return generate_tuple( [&](auto i) { @@ -949,7 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t K, const index_t StrideA, const index_t StrideB, -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) const ck::Array StrideDs, #else const std::array StrideDs, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 9dad66913a..f8de0a48e5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -2,7 +2,8 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#ifndef CK_CODE_GEN_RTC + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #endif @@ -54,7 +55,7 @@ constexpr auto GridwiseGemmPipeline_Selector() } else { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; #endif } @@ -62,7 +63,7 @@ constexpr auto GridwiseGemmPipeline_Selector() } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) { switch(p) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index e04f24c989..fcb12f4a14 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -230,6 +230,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both M and K to be multiples of the block sizes + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MKPadding || @@ -296,6 +313,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return a_grid_desc_ak0_m_ak1; } +#endif } __device__ static auto MakeBGridDescriptor_BK0_N_BK1( @@ -312,6 +330,23 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both N and K to be multiples of the block sizes + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -378,6 +413,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 return b_grid_desc_bk0_n_bk1; } +#endif } template @@ -412,6 +448,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }(); + // Pad both M and N to be multiples of the block sizes + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -449,6 +492,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 // not pad M or N return c_grid_desc_mraw_nraw; } +#endif } struct Problem @@ -953,7 +997,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) { if(!(karg.M % MPerBlock == 0)) { @@ -970,7 +1015,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) { if(!(karg.N % NPerBlock == 0)) { @@ -1036,6 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1051,6 +1098,10 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; return false; } } @@ -1065,6 +1116,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1082,6 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } } @@ -1098,17 +1151,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } - return false; - } - } - if constexpr(is_same, bhalf_t>::value) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " Grid size: " << karg.Grid_size << " > 1 is not support yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; + return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp old mode 100644 new mode 100755 index 55360fc0d0..d4c915aa5e --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -674,11 +674,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(Number{} * AK1Number, AK1Number, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. @@ -810,11 +812,13 @@ struct GridwiseGemm_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(Number{} * BK1Number, BK1Number, I1)); } else if constexpr(is_same::value) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index e5a31f8d1f..a9e73bf461 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -676,11 +676,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(Number{} * AK1Number, AK1Number, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. @@ -813,11 +815,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(Number{} * BK1Number, BK1Number, I1)); } else if constexpr(is_same::value) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp new file mode 100644 index 0000000000..238ab14606 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -0,0 +1,1968 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NLane; + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 1 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + p_shared1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 2>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index baf14b2575..7ccea96dda 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -352,6 +352,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } } + template + __device__ constexpr auto + GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) + { + using vector_t = typename vector_type_maker::type::type; + return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); + } + template __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 4f20487b9b..8c0b950941 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -780,7 +780,6 @@ struct mfma_type } }; -// TODO: fix mfma...f8f6f4 instructions template <> struct mfma_type { @@ -847,9 +846,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; @@ -871,9 +875,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index bc290d5641..f34e0e59b3 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,6 +34,94 @@ struct TransformConvBwdWeightToGemmV2 static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[3]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t WiStride = input_strides[3]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[3]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + template ::type = false> constexpr static auto make_out_grid_desc(const index_t N, @@ -221,6 +309,187 @@ struct TransformConvBwdWeightToGemmV2 make_tuple(Sequence<0>{}, Sequence<1>{})); } + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + + const index_t Wo = output_spatial_lengths[0]; + + const index_t X = filter_spatial_lengths[0]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + const index_t ConvDilationW = conv_filter_dilations[0]; + + const index_t InLeftPadW = input_left_pads[0]; + + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * NumGroupsToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3, 4>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + + } // function end + template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( const index_t N, diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 328e37d009..317f324e6d 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1008,6 +1008,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); +#ifndef __HIPCC_RTC__ template __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset, @@ -1059,5 +1060,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); #endif } +#endif } // namespace ck diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 128c8e9a2c..3604712837 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -7,7 +7,7 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include #include diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b125e3adf6..010b7aabd3 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -533,9 +533,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; @@ -569,9 +569,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 86dcb6c157..574be5c4ad 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -8,6 +8,19 @@ namespace ck { +enum struct BlockGemmPipelineVersion +{ + // For GEMM + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer + + // For GEMM with preshuffled weight + // v1, single lds buffer + // v2, double lds buffer +}; enum struct BlockGemmPipelineScheduler { Intrawave, diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index ee72b9bcd9..872d48fee0 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -6,16 +6,20 @@ #include "ck/utility/amd_ck_fp8.hpp" #include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" -#ifdef CK_CODE_GEN_RTC + +/// Definitions from , conflict with +/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h. + +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; using uint16_t = unsigned short; using float_t = float; -#endif -namespace ck { +#endif // __HIPCC_RTC__ -#ifdef CK_CODE_GEN_RTC +namespace ck { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) using byte = unsigned char; #else using std::byte; @@ -2612,7 +2616,7 @@ using pk_i4x2_t = typename vector_type::type; using pk_i4x4_t = typename vector_type::type; using pk_i4x8_t = typename vector_type::type; -#ifdef CK_CODE_GEN_RTC +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template struct NumericLimits; @@ -2825,6 +2829,118 @@ struct NumericLimits return bit_cast(binary_qnan); } }; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; #else template struct NumericLimits @@ -2959,7 +3075,6 @@ struct NumericLimits return bit_cast(binary_qnan); } }; -#endif template <> struct NumericLimits @@ -3072,6 +3187,7 @@ struct NumericLimits return e8m0_bexp_t(binary_142); } }; +#endif template struct NumericUtils diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 6ba63fc761..9d5403ceb2 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -4,15 +4,7 @@ #pragma once namespace ck { - -#ifndef CK_CODE_GEN_RTC -template -using enable_if = std::enable_if; - -template -using enable_if_t = typename std::enable_if::type; - -#else +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) template struct enable_if { @@ -26,6 +18,12 @@ struct enable_if template using enable_if_t = typename enable_if::type; -#endif +#else +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; +#endif } // namespace ck diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index 837ff66312..cbbce85007 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_CODE_GEN_RTC +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #endif -#pragma once - #include "ck/utility/common_header.hpp" namespace ck { @@ -28,7 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler() } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) { switch(s) diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 03eb7c646d..05ae9093e2 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/ck.hpp" +#include "data_type.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" @@ -34,7 +35,7 @@ struct MagicDivision // WARNING: magic division is only applicable for division inside this range. // You should use the return value of CalculateMagicNumbers, if division is not inside this // range. The "else" logic below is to quiet down run-time error. - if(divisor >= 1 && divisor <= INT32_MAX) + if(divisor >= 1 && divisor <= ck::NumericLimits::Max()) { uint32_t shift = 0; for(shift = 0; shift < 32; ++shift) diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index b31b46fb5f..e235f51c93 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float); #endif // math functions for the host, some are implemented by calling C++ std functions -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); }; @@ -924,5 +924,23 @@ inline __device__ double expm1(double x) return expm1(x); }; +template +inline __device__ T cos(T x) +{ + return ck::type_convert(cosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cos(float x) +{ + return cosf(x); +}; + +template <> +inline __device__ double cos(double x) +{ + return cos(x); +}; + } // namespace math } // namespace ck diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6061d48118..25dae4e335 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -3,7 +3,7 @@ #pragma once -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #endif @@ -902,7 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen::type; } // namespace ck -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template std::ostream& operator<<(std::ostream& os, const ck::Sequence) { diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 596c748a2a..b4f1545aa9 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -159,7 +159,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple& tuple) } } -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template using is_tuple = decltype(ck::declval().IsTuple()); #endif @@ -167,7 +167,7 @@ using is_tuple = decltype(ck::declval().IsTuple()); template __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) { -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) return (is_detected::value || ...); #endif } diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index ef9326ae57..bde9c179ce 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,316 +1,313 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/utility/enable_if.hpp" -#include "ck/utility/integral_constant.hpp" - -namespace ck { -#ifdef CK_CODE_GEN_RTC -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAIT1(name) \ - template \ - struct name : bool_constant<__##name(T)> \ - { \ - } - -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAIT2(name) \ - template \ - struct name : bool_constant<__##name(T, U)> \ - { \ - } - -// NOLINTNEXTLINE -#define CK_BUILTIN_TYPE_TRAITN(name) \ - template \ - struct name : bool_constant<__##name(Ts...)> \ - { \ - } - -CK_BUILTIN_TYPE_TRAIT1(is_class); -CK_BUILTIN_TYPE_TRAIT1(is_pointer); -CK_BUILTIN_TYPE_TRAIT1(is_reference); -CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); -CK_BUILTIN_TYPE_TRAIT1(is_unsigned); -CK_BUILTIN_TYPE_TRAIT2(is_base_of); - -template -struct remove_cv -{ - using type = T; -}; - -template -struct remove_cv : remove_cv -{ -}; - -template -struct remove_cv : remove_cv -{ -}; - -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_reference -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; -template -struct remove_pointer -{ - typedef T type; -}; - -template -constexpr T&& forward(typename remove_reference::type& t_) noexcept -{ - return static_cast(t_); -} -template -constexpr T&& forward(typename remove_reference::type&& t_) noexcept -{ - return static_cast(t_); -} - -template -struct is_const : public integral_constant -{ -}; -template -struct is_const : public integral_constant -{ -}; -template -inline constexpr bool is_const_v = is_const::value; - -template -inline constexpr bool is_reference_v = is_reference::value; - -template -struct remove_const -{ - typedef T type; -}; -template -struct remove_const -{ - typedef T type; -}; -template -using remove_const_t = typename remove_const::type; -template -inline constexpr bool is_class_v = is_class::value; - -template -inline constexpr bool is_trivially_copyable_v = is_trivially_copyable::value; -// template -// T&& declval() noexcept; - -template -U private_declval(int); - -template -T private_declval(long); - -template -auto declval() noexcept -> decltype(private_declval(0)); - -template -using void_t = void; -#else -#include -#include -using std::declval; -using std::forward; -using std::is_base_of; -using std::is_class; -using std::is_class_v; -using std::is_const_v; -using std::is_pointer; -using std::is_reference; -using std::is_reference_v; -using std::is_trivially_copyable; -using std::is_trivially_copyable_v; -using std::is_unsigned; -using std::remove_const_t; -using std::remove_cv; -using std::remove_pointer; -using std::remove_reference; -using std::void_t; -#endif - -template -struct is_same : public integral_constant -{ -}; - -template -struct is_same : public integral_constant -{ -}; - -template -struct is_floating_point : public integral_constant -{ -}; - -template <> -struct is_floating_point : public integral_constant -{ -}; - -template <> -struct is_floating_point : public integral_constant -{ -}; -template <> -struct is_floating_point : public integral_constant -{ -}; - -template -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template <> -struct is_integral : public integral_constant -{ -}; - -template -inline constexpr bool is_same_v = is_same::value; - -template -inline constexpr bool is_base_of_v = is_base_of::value; - -template -inline constexpr bool is_unsigned_v = is_unsigned::value; - -template -using remove_reference_t = typename remove_reference::type; - -template -using remove_reference_t = typename remove_reference::type; - -template -using remove_cv_t = typename remove_cv::type; -template -using remove_cvref_t = remove_cv_t>; - -template -using remove_pointer_t = typename remove_pointer::type; - -template -inline constexpr bool is_pointer_v = is_pointer::value; - -template ::type = false> -__host__ __device__ constexpr Y bit_cast(const X& x) -{ - static_assert(__has_builtin(__builtin_bit_cast), ""); - static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); - - return __builtin_bit_cast(Y, x); -} -} // namespace ck +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAIT1(name) \ + template \ + struct name : bool_constant<__##name(T)> \ + { \ + } + +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAIT2(name) \ + template \ + struct name : bool_constant<__##name(T, U)> \ + { \ + } + +// NOLINTNEXTLINE +#define CK_BUILTIN_TYPE_TRAITN(name) \ + template \ + struct name : bool_constant<__##name(Ts...)> \ + { \ + } + +CK_BUILTIN_TYPE_TRAIT1(is_class); +CK_BUILTIN_TYPE_TRAIT1(is_pointer); +CK_BUILTIN_TYPE_TRAIT1(is_reference); +CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable); +CK_BUILTIN_TYPE_TRAIT1(is_unsigned); +CK_BUILTIN_TYPE_TRAIT2(is_base_of); + +template +struct remove_cv +{ + using type = T; +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_cv : remove_cv +{ +}; + +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_reference +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; +template +struct remove_pointer +{ + typedef T type; +}; + +template +constexpr T&& forward(typename remove_reference::type& t_) noexcept +{ + return static_cast(t_); +} +template +constexpr T&& forward(typename remove_reference::type&& t_) noexcept +{ + return static_cast(t_); +} + +template +struct is_const : public integral_constant +{ +}; +template +struct is_const : public integral_constant +{ +}; +template +inline constexpr bool is_const_v = is_const::value; + +template +inline constexpr bool is_reference_v = is_reference::value; + +template +struct remove_const +{ + typedef T type; +}; +template +struct remove_const +{ + typedef T type; +}; +template +using remove_const_t = typename remove_const::type; +template +inline constexpr bool is_class_v = is_class::value; + +template +inline constexpr bool is_trivially_copyable_v = is_trivially_copyable::value; +// template +// T&& declval() noexcept; + +template +U private_declval(int); + +template +T private_declval(long); + +template +auto declval() noexcept -> decltype(private_declval(0)); + +template +using void_t = void; +#else +#include +#include +using std::declval; +using std::forward; +using std::is_base_of; +using std::is_class; +using std::is_class_v; +using std::is_const_v; +using std::is_pointer; +using std::is_reference; +using std::is_reference_v; +using std::is_trivially_copyable; +using std::is_trivially_copyable_v; +using std::is_unsigned; +using std::remove_const_t; +using std::remove_cv; +using std::remove_pointer; +using std::remove_reference; +using std::void_t; +#endif + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_floating_point : public integral_constant +{ +}; + +template <> +struct is_floating_point : public integral_constant +{ +}; + +template <> +struct is_floating_point : public integral_constant +{ +}; +template <> +struct is_floating_point : public integral_constant +{ +}; + +template +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template <> +struct is_integral : public integral_constant +{ +}; + +template +inline constexpr bool is_same_v = is_same::value; + +template +inline constexpr bool is_base_of_v = is_base_of::value; + +template +inline constexpr bool is_unsigned_v = is_unsigned::value; + +template +using remove_reference_t = typename remove_reference::type; + +template +using remove_cv_t = typename remove_cv::type; +template +using remove_cvref_t = remove_cv_t>; + +template +using remove_pointer_t = typename remove_pointer::type; + +template +inline constexpr bool is_pointer_v = is_pointer::value; + +template ::type = false> +__host__ __device__ constexpr Y bit_cast(const X& x) +{ + static_assert(__has_builtin(__builtin_bit_cast), ""); + static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); + + return __builtin_bit_cast(Y, x); +} +} // namespace ck diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index cda607bb30..51836fd60b 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -279,7 +279,6 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; - #ifndef CK_CODE_GEN_RTC uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else @@ -344,7 +343,6 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 1254739; - #ifndef CK_CODE_GEN_RTC uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else @@ -1534,7 +1532,7 @@ inline __host__ __device__ float32_t type_convert(bf6x32_t #endif } -#ifndef CK_CODE_GEN_RTC +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template inline __host__ __device__ void array_convert(std::array& y, const std::array& x) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 107aae5516..4e0deb1547 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1309,7 +1309,9 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); using rtn_type = thread_buffer; diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 78768bbbfc..fa63597db4 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,12 +153,12 @@ struct array CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } }; -template +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, void> { using scalar_type = T; static constexpr index_t vector_size = N; diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index 279a48acb3..77c46e1b8c 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -149,17 +149,24 @@ struct thread_buffer { }; // clang-format on -template +template struct vector_traits; // specialization for array template -struct vector_traits> +struct vector_traits, std::enable_if_t>> { using scalar_type = T; static constexpr index_t vector_size = N; }; +template +struct vector_traits, std::enable_if_t>> +{ + using scalar_type = typename T::type; + static constexpr index_t vector_size = N; +}; + #endif } // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 74575f4c6e..fd02177e25 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -294,7 +294,7 @@ struct tuple : impl::tuple_base, T...> #undef TP_COM_ }; -template +template struct vector_traits; // specialization for array diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 6ad38b1f7c..6f31468809 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -376,14 +376,12 @@ struct numeric } }; -template -struct numeric_traits; - template <> struct numeric_traits { - static constexpr int exp = 8; - static constexpr int mant = 7; + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int PackedSize = 1; }; #if CK_TILE_USE_CUSTOM_DATA_TYPE diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index c4fc6890c6..facc3e45ee 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -207,9 +207,6 @@ using bf8_t = unsigned _BitInt(8); using bf8_raw_t = uint8_t; #endif -template -struct numeric_traits; - template <> struct numeric_traits { @@ -225,6 +222,7 @@ struct numeric_traits static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ; #endif static constexpr uint8_t abs_mask = 0x7F; + static constexpr int PackedSize = 1; }; template <> @@ -242,6 +240,7 @@ struct numeric_traits static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ; #endif static constexpr uint8_t abs_mask = 0x7F; + static constexpr int PackedSize = 1; }; // below is sw fp8 conversion, not utilizing hw instruction diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 5779b170b7..8479b33f8f 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -223,9 +223,6 @@ struct numeric } }; -template -struct numeric_traits; - template <> struct numeric_traits { @@ -241,6 +238,7 @@ struct numeric_traits static constexpr uint16_t NegInf = 0xFC00; static constexpr uint16_t NaN = 0x7C01; static constexpr uint16_t Neg0 = 0x8000; + static constexpr int PackedSize = 1; using bitwise_type = uint16_t; }; @@ -383,4 +381,24 @@ half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))) CK_TILE_DEVICE half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; #endif + +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); + +CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t vector_res; + + vector_res.x = x.x + y.x; + vector_res.y = x.y + y.y; + + return vector_res; +} + +CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) +{ + fp16x2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y)); + return c; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/int8.hpp b/include/ck_tile/core/numeric/int8.hpp index 9ca3333c39..34d9a1c4b9 100644 --- a/include/ck_tile/core/numeric/int8.hpp +++ b/include/ck_tile/core/numeric/int8.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" @@ -74,8 +74,6 @@ struct numeric }; #if 0 -template -struct numeric_traits; template <> struct numeric_traits @@ -91,6 +89,7 @@ struct numeric_traits static constexpr uint32_t NegInf = 0xFC00; static constexpr uint32_t NaN = 0x7C01; static constexpr uint32_t Neg0 = 0x8000; + static constexpr int PackedSize = 1; using bitwise_type = uint16_t; }; #endif diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp index 6b16485b48..f125fbf2ce 100644 --- a/include/ck_tile/core/numeric/numeric.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -77,7 +77,10 @@ struct numeric }; template -struct numeric_traits; +struct numeric_traits +{ + static constexpr int PackedSize = 1; +}; template <> struct numeric_traits @@ -94,6 +97,7 @@ struct numeric_traits static constexpr uint32_t NegInf = 0xFF800000; static constexpr uint32_t NaN = 0x7F800001; static constexpr uint32_t Neg0 = 0x80000000; + static constexpr int PackedSize = 1; using bitwise_type = uint32_t; }; diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 2ffcc36ced..541093e337 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -21,8 +21,8 @@ struct pk_int4_t { using type = int8_t; type data; - __host__ __device__ constexpr pk_int4_t() : data{type{}} {} - __host__ __device__ constexpr pk_int4_t(type init) : data{init} {} + CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {} + CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {} }; // limits @@ -91,6 +91,16 @@ struct numeric CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; } }; +template <> +struct numeric_traits +{ + static constexpr int PackedSize = 2; +}; + +using fp32x2_t = float __attribute__((ext_vector_type(2))); +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); + CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { uint8_t x_u8 = ck_tile::bit_cast(x); diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 480da96596..b165275a8c 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -10,6 +10,7 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -30,17 +31,34 @@ struct native_t // of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will // have compiler error namespace impl { + +template +struct ext_vector; + template -struct ext_vector +struct ext_vector::type>>> { static constexpr index_t N = N_; - using value_type = typename native_t>::type; + // struct type is not supported for ext_vector + using value_type = typename native_t::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + +template +struct ext_vector::type>>> +{ + static constexpr index_t N = N_; + // struct type is not supported for ext_vector + using value_type = typename native_t::type::type; static_assert(!std::is_class_v); using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; template -struct ext_vector +struct ext_vector::type>>> { static constexpr index_t N = Vs_ * N_; using value_type = typename native_t>::type; @@ -48,6 +66,17 @@ struct ext_vector using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; +template +struct ext_vector::type>>> +{ + static constexpr index_t N = Vs_ * N_; + using value_type = typename native_t>::type::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + } // namespace impl template @@ -55,10 +84,11 @@ using ext_vector_t = typename impl::ext_vector::type; // by default, any type will result in a vector_size=1 with scalar_type=T traits. // ... unless we have other vector_traits specialization -template +template struct vector_traits { - using scalar_type = remove_cvref_t; + using scalar_type = + std::conditional_t, pk_int4_t>, int8_t, remove_cvref_t>; static constexpr index_t vector_size = 1; }; @@ -66,7 +96,7 @@ struct vector_traits template struct vector_traits { - using scalar_type = T; + using scalar_type = std::conditional_t, int8_t, T>; static constexpr index_t vector_size = N; }; @@ -200,21 +230,11 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); #endif -CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) -{ - fp16x2_t vector_res; - - vector_res.x = x.x + y.x; - vector_res.y = x.y + y.y; - - return vector_res; -} - -CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) -{ - fp16x2_t c; - asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y)); - return c; -} - +// pk_int4_t +// using pk_int4_t +using pk_int4x2_t = int8_t __attribute((ext_vector_type(2))); +using pk_int4x4_t = int8_t __attribute((ext_vector_type(4))); +using pk_int4x8_t = int8_t __attribute((ext_vector_type(8))); +using pk_int4x16_t = int8_t __attribute((ext_vector_type(16))); +using pk_int4x32_t = int8_t __attribute((ext_vector_type(32))); } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 7dffa0e555..c2a093f1ab 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -231,13 +231,18 @@ struct buffer_view invalid_element_value_ = T{0}; + static constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + CK_TILE_HOST_DEVICE constexpr buffer_view() : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} + : p_data_{p_data}, + buffer_size_{buffer_size / PackedSize}, + cached_buf_res_{0}, + invalid_element_value_{0} { } @@ -245,7 +250,7 @@ struct buffer_view>::scalar_type, - int8_t>::value && + if constexpr(std::is_same_v>::scalar_type, + int8_t> && workaround_int8_ds_write_issue) { if(is_valid_element) @@ -897,83 +902,117 @@ struct buffer_view" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - static_assert((std::is_same, int8_t>::value && - std::is_same, int8_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x2_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x4_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x8_t>::value) || - (std::is_same, int8_t>::value && - std::is_same, int8x16_t>::value) || - (std::is_same, int8x4_t>::value && - std::is_same, int8x4_t>::value) || - (std::is_same, int8x8_t>::value && - std::is_same, int8x8_t>::value) || - (std::is_same, int8x16_t>::value && - std::is_same, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); + static_assert( + (std::is_same_v, int8_t> && + std::is_same_v, int8_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x2_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, int8_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, int8x4_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, int8x8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, int8x16_t> && + std::is_same_v, int8x16_t>) || + // ext_vector_type for pk_int4 must use int8_t as type + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x4_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x8_t> && + std::is_same_v, thread_buffer>) || + (std::is_same_v, pk_int4x16_t> && + std::is_same_v, thread_buffer>), + "wrong! not implemented for this combination, please add " + "implementation"); - if constexpr(std::is_same, int8_t>::value && - std::is_same, int8_t>::value) + if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x2_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x2_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x4_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x8_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8_t>::value && - std::is_same, int8x16_t>::value) + else if constexpr((std::is_same_v, int8_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, pk_int4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x4_t>::value && - std::is_same, int8x4_t>::value) + else if constexpr((std::is_same_v, int8x4_t> && + std::is_same_v, int8x4_t>) || + (std::is_same_v, pk_int4x4_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x8_t>::value && - std::is_same, int8x8_t>::value) + else if constexpr((std::is_same_v, int8x8_t> && + std::is_same_v, int8x8_t>) || + (std::is_same_v, pk_int4x8_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(std::is_same, int8x16_t>::value && - std::is_same, int8x16_t>::value) + else if constexpr((std::is_same_v, int8x16_t> && + std::is_same_v, int8x16_t>) || + (std::is_same_v, pk_int4x16_t> && + std::is_same_v, thread_buffer>)) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 8d2f88af39..b73a27c8d5 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -27,6 +27,8 @@ struct static_distributed_tensor using ThreadTensorDesc = remove_cvref_t; + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid"); @@ -59,7 +61,7 @@ struct static_distributed_tensor CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size() { - return kThreadElementSpaceSize; + return kThreadElementSpaceSize / PackedSize; } template @@ -79,8 +81,9 @@ struct static_distributed_tensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; - sliced_thread_data(number{}) = - thread_buf_[number{}]; + sliced_thread_data( + number{}) = + thread_buf_[number{}]; }); return sliced_thread_data; @@ -101,8 +104,9 @@ struct static_distributed_tensor static_ford>{}([&](auto idx) { constexpr auto idx_ys = idx + sequence{}; - thread_buf_(number{}) = - sliced_thread_data[number{}]; + thread_buf_(number{}) = + sliced_thread_data[number{}]; }); } @@ -115,7 +119,7 @@ struct static_distributed_tensor constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( TileDistributedIndices{}); - return thread_buf_[number{}]; + return thread_buf_[number{}]; } template @@ -127,11 +131,11 @@ struct static_distributed_tensor constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( TileDistributedIndices{}); - return thread_buf_(number{}); + return thread_buf_(number{}); } // - thread_buffer thread_buf_; + thread_buffer thread_buf_; }; template diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 4c72ed0859..336793c5b1 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,6 +45,8 @@ struct tensor_view using TensorIndex = array; using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); static constexpr auto DstInMemOp = DstInMemOp_; + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; CK_TILE_HOST_DEVICE constexpr tensor_view() = default; @@ -81,8 +83,8 @@ struct tensor_view bool_constant = {}) const { return buf_.template get( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -99,8 +101,8 @@ struct tensor_view bool is_valid_element, // flag bool_constant = {}) const { - return buf_.template get(coord.get_offset(), - linear_offset, + return buf_.template get(coord.get_offset() / PackedSize, + linear_offset / PackedSize, is_valid_element, bool_constant{}); } @@ -122,8 +124,8 @@ struct tensor_view { return buf_.template get_raw( dst, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -142,8 +144,12 @@ struct tensor_view bool_constant = {}, bool_constant = {}) const { - return buf_.template get_raw( - dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant{}); + return buf_.template get_raw(dst, + coord.get_offset() / + PackedSize, + linear_offset / PackedSize, + is_valid_element, + bool_constant{}); } template ( smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -178,8 +184,8 @@ struct tensor_view bool is_valid_element) const { return buf_.template async_get(smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, is_valid_element, bool_constant{}); } @@ -198,8 +204,8 @@ struct tensor_view { return buf_.template async_get_raw( smem, - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), bool_constant{}); } @@ -217,8 +223,11 @@ struct tensor_view bool is_valid_element, bool_constant = {}) const { - return buf_.template async_get_raw( - smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant{}); + return buf_.template async_get_raw(smem, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, + is_valid_element, + bool_constant{}); } // X is vector of DataType. @@ -236,8 +245,8 @@ struct tensor_view bool_constant = {}) { buf_.template set( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -272,8 +281,8 @@ struct tensor_view bool_constant = {}) { buf_.template set_raw( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -292,7 +301,7 @@ struct tensor_view bool_constant = {}) { buf_.template set_raw( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } // X is vector of DataType. @@ -310,8 +319,8 @@ struct tensor_view bool_constant = {}) { buf_.template update( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -330,7 +339,7 @@ struct tensor_view bool_constant = {}) { buf_.template update( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } // X is vector of DataType. @@ -350,8 +359,8 @@ struct tensor_view bool_constant = {}) { buf_.template update_raw( - coord.get_offset(), - linear_offset, + coord.get_offset() / PackedSize, + linear_offset / PackedSize, coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), x); } @@ -372,7 +381,7 @@ struct tensor_view bool_constant = {}) { buf_.template update_raw( - coord.get_offset(), linear_offset, is_valid_element, x); + coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } CK_TILE_HOST_DEVICE void print() const diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 27c2c24ad5..3bb728df23 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -97,13 +97,15 @@ struct tile_window_with_static_distribution } public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>(); // using vector_type_t = vector_type_maker_t; // using vector_t = typename vector_type_t::type; - using vector_t = thread_buffer; + using vector_t = thread_buffer; private: static constexpr auto scalars_per_access_ = [] { @@ -336,7 +338,7 @@ struct tile_window_with_static_distribution bottom_tensor_thread_coord, 0, bool_constant{}); #if 1 // write into distributed tensor - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -345,10 +347,11 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / Traits::PackedSize]; }); #else constexpr index_t d = @@ -390,8 +393,9 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; static constexpr index_t YElementSize = TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % Traits::ScalarPerVector == 0); - using vectorized_tbuf = array; + static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0); + using vectorized_tbuf = + array; // StaticBuffer( @@ -632,7 +637,7 @@ struct tile_window_with_static_distribution // vector_type_t vec; vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -641,9 +646,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -698,7 +704,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -706,8 +712,9 @@ struct tile_window_with_static_distribution }, number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec_value.template get_as()(j) = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -759,7 +766,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -768,9 +775,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -825,7 +833,7 @@ struct tile_window_with_static_distribution // read from distributed tensor vector_t vec_value; - static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) @@ -834,9 +842,10 @@ struct tile_window_with_static_distribution number{}); constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 96a8352c04..1e24e660f6 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/arch/arch.hpp" @@ -151,11 +151,13 @@ struct tile_window_linear } public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); static constexpr index_t ScalarPerVector = get_vector_dim_y_scalar_per_vector().template at<1>(); - using vector_t = thread_buffer; + using vector_t = thread_buffer; private: static constexpr auto scalars_per_access_ = [] { @@ -498,17 +500,18 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / traits::PackedSize]; }); #else constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); @@ -556,17 +559,18 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j]; + vec_value.template get_as()[j / traits::PackedSize]; }); #else constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); @@ -595,8 +599,9 @@ struct tile_window_linear using SFC_Ys = typename traits::SFC_Ys; static constexpr index_t YElementSize = TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % traits::ScalarPerVector == 0); - using vectorized_tbuf = array; + static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0); + using vectorized_tbuf = + array; constexpr auto tile_dstr = TileDstr{}; @@ -620,7 +625,9 @@ struct tile_window_linear // data index [y0, y1, ...] constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) / + traits::PackedSize; static_assert(d % traits::ScalarPerVector == 0); get_bottom_tensor_view().template get_vectorized_elements_raw( @@ -804,16 +811,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -852,14 +860,15 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); - vec_value.template get_as()(j) = + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -897,16 +906,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); @@ -948,16 +958,17 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, number{}); - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + traits::PackedSize; - vec_value.template get_as()(j) = + vec_value.template get_as()(j / traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index ea70563d58..745c18d6dd 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -29,11 +29,12 @@ double get_relative_threshold(const int number_of_accumulations = 1) using I8 = int8_t; using I32 = int32_t; - static_assert(is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + static_assert( + is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -42,11 +43,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) compute_error = std::pow(2, -numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -56,11 +57,11 @@ double get_relative_threshold(const int number_of_accumulations = 1) } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -82,12 +83,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using I8 = int8_t; using I32 = int32_t; - static_assert(is_any_of::value, - "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + static_assert( + is_any_of::value, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -96,11 +98,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; } - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } @@ -110,11 +112,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of } double midway_error = std::max(compute_error, output_error); - static_assert(is_any_of::value, + static_assert(is_any_of::value, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; - if constexpr(is_any_of::value) + if constexpr(is_any_of::value) { return 0; } diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index f24c338755..006026470b 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -282,7 +282,14 @@ struct FillMonotonicSeq { std::generate(first, last, [=, n = init_value_]() mutable { auto tmp = n; - n += step_; + if constexpr(std::is_same_v) + { + n.data += step_.data; + } + else + { + n += step_; + } return tmp; }); } diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 2047ad7793..a43877c6da 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -281,18 +281,18 @@ struct HostTensor using Data = std::vector; template - HostTensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + HostTensor(std::initializer_list lens) : mDesc(lens), mData(get_element_space_size()) { } template HostTensor(std::initializer_list lens, std::initializer_list strides) - : mDesc(lens, strides), mData(mDesc.get_element_space_size()) + : mDesc(lens, strides), mData(get_element_space_size()) { } template - HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size()) + HostTensor(const Lengths& lens) : mDesc(lens), mData(get_element_space_size()) { } @@ -302,7 +302,7 @@ struct HostTensor { } - HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {} + HostTensor(const Descriptor& desc) : mDesc(desc), mData(get_element_space_size()) {} template HostTensor CopyAsType() const @@ -340,7 +340,11 @@ struct HostTensor std::size_t get_element_size() const { return mDesc.get_element_size(); } - std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); } + std::size_t get_element_space_size() const + { + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + return mDesc.get_element_space_size() / PackedSize; + } std::size_t get_element_space_size_in_bytes() const { @@ -463,29 +467,27 @@ struct HostTensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - return mDesc.GetOffsetFromMultiIndex(is...); + constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; + return mDesc.GetOffsetFromMultiIndex(is...) / PackedSize; } template T& operator()(Is... is) { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + return mData[GetOffsetFromMultiIndex(is...)]; } template const T& operator()(Is... is) const { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + return mData[GetOffsetFromMultiIndex(is...)]; } - T& operator()(std::vector idx) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + T& operator()(std::vector idx) { return mData[GetOffsetFromMultiIndex(idx)]; } const T& operator()(std::vector idx) const { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + return mData[GetOffsetFromMultiIndex(idx)]; } HostTensor transpose(std::vector axes = {}) const diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index da0de457d4..fe5077083c 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,11 +34,35 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, for(std::size_t k = 0; k < K; ++k) { - ADataType v_a = a_element_op(a_m_k(m, k)); - BDataType v_b = b_element_op(b_k_n(k, n)); - - v_acc += - ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + } + if constexpr(std::is_same_v) + { + const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); @@ -73,6 +97,8 @@ __global__ void naive_gemm_kernel(ADataType* A, AccDataType acc = 0.0; for(int k = 0; k < K; ++k) { + constexpr index_t packed_size_a = ck_tile::numeric_traits::PackedSize; + constexpr index_t packed_size_b = ck_tile::numeric_traits::PackedSize; // Adjust indexing based on matrix layout int a_index = (std::is_same_v) ? row * strideA + k @@ -80,8 +106,34 @@ __global__ void naive_gemm_kernel(ADataType* A, int b_index = (std::is_same_v) ? col * strideB + k : k * strideB + col; - acc += ck_tile::type_convert(A[a_index]) * - ck_tile::type_convert(B[b_index]); + + AccDataType v_a; + AccDataType v_b; + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]); + if(k % 2 == 1) + v_a = fp32_val.hi; + else + v_a = fp32_val.lo; + } + else + { + v_a = ck_tile::type_convert(A[a_index]); + } + if constexpr(std::is_same_v) + { + const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]); + if(k % 2 == 1) + v_b = fp32_val.hi; + else + v_b = fp32_val.lo; + } + else + { + v_b = ck_tile::type_convert(B[b_index]); + } + acc += v_a * v_b; } int c_index = (std::is_same_v) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index 24f35d3636..64e5224780 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -157,7 +157,7 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass sweep_tile(qy, [&, yscale_ = yscale](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); auto qy_ = y[idx] / yscale_[i_idx]; - qy(idx) = saturates{}(qy_); + qy(idx) = type_convert(saturates{}(qy_)); }); store_tile(qy_window, qy); } diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index aec7368e27..ecd4e81b22 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -260,7 +260,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass const auto x_ = type_convert(x[idx]); auto y_ = x_ * inv_rms[i_idx] * gamma_; auto qy_ = y_ / yscale[i_idx]; - qy(idx) = saturates{}(qy_); + qy(idx) = type_convert(saturates{}(qy_)); }); store_tile(qy_window, qy); diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3e8dac30ef..a3a0df996d 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,20 +9,166 @@ namespace ck_tile { namespace element_wise { -#if 0 +// Fast int4x4 to fp16x8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +CK_TILE_DEVICE fp16x4_t i4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + int lo; + int hi; + // Extract the two int4 at low bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX)); + // Extract the two int4 at hight bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX)); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + fp16x4_t res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(res.lo) + : "v"(bit_cast(lo)), "v"(bit_cast(SUB))); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + asm volatile( + "v_pk_fma_f16 %0, %1, %2, %3" + : "=v"(res.hi) + : "v"(bit_cast(hi)), "v"(bit_cast(MUL)), "v"(bit_cast(ADD))); + + return res; +} + +CK_TILE_DEVICE fp16x4_t i4_to_half4_scale(int q, const fp16x2_t& scale) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + int lo; + int hi; + // Extract the two int4 at low bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(lo) : "v"(q), "v"(LO), "v"(EX)); + // Extract the two int4 at hight bit and create two fp16 number. + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(hi) : "v"(q), "v"(HI), "v"(EX)); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + fp16x4_t res; + + asm volatile("v_pk_add_f16 %0, %1, %2" + : "=v"(res.lo) + : "v"(bit_cast(lo)), "v"(bit_cast(SUB))); + + asm volatile( + "v_pk_fma_f16 %0, %1, %2, %3" + : "=v"(res.hi) + : "v"(bit_cast(hi)), "v"(bit_cast(MUL)), "v"(bit_cast(ADD))); + + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.lo) : "v"(res.lo), "v"(scale)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(res.hi) : "v"(res.hi), "v"(scale)); + + return res; +} + +CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + bf16x4_t res; + res.lo = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.hi = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res; +} + +struct PassThroughPack8 +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; + + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x8_t& y, const pk_int4x4_t& x) const + { + y.lo = i4_to_half4(bit_cast(x)); + y.hi = i4_to_half4(bit_cast(x) >> 8); + } + + CK_TILE_HOST_DEVICE constexpr void operator()(bf16x8_t& y, const pk_int4x4_t& x) const + { + y.lo = i4_to_bhalf4(bit_cast(x)); + y.hi = i4_to_bhalf4(bit_cast(x) >> 16); + } + constexpr const static bool is_pack8_invocable = true; +}; + +struct DequantPack8 +{ + template + CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const; + + CK_TILE_HOST_DEVICE constexpr void + operator()(fp16x8_t& y, const pk_int4x4_t& x, const fp16x2_t& z) const + { + y.lo = i4_to_half4_scale(bit_cast(x), z); + y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); + } + + constexpr const static bool is_pack8_invocable = true; +}; + struct PassThroughPack2 { template CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const; - CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const +#if 0 + CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::fp16x2_t& y, const ck_tile::f8x2_t& x) const { auto t = type_convert(x); - y = type_convert(t); + y = type_convert(t); } +#endif + + CK_TILE_HOST_DEVICE constexpr void operator()(fp16x2_t& y, const pk_int4_t& x) const + { + uint8_t x_u8 = bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; + + y.lo = type_convert(x_l); + y.hi = type_convert(x_h); + } + constexpr const static bool is_pack2_invocable = true; }; -#endif struct PassThrough { diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 1569c93565..726543b97a 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -310,7 +310,7 @@ struct SimplifiedGenericAttentionMask const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); const index_t split_start = x_per_split * i_split; - const index_t split_end = split_start + x_per_split; + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), ck_tile::min(origin_end, split_end)); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 92dc2bac3f..14d0596287 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -742,7 +742,7 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index ab21398b99..d9d6739fb5 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise.hpp" namespace ck_tile { @@ -20,12 +21,13 @@ struct BlockUniversalGemmAsBsCr template struct GemmTraits_ { - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr auto Scheduler = Problem::Scheduler; @@ -71,10 +73,10 @@ struct BlockUniversalGemmAsBsCr using BWarpTileDistr = remove_cvref_t; - using AWarpTile = - remove_cvref_t(AWarpTileDistr{}))>; - using BWarpTile = - remove_cvref_t(BWarpTileDistr{}))>; + using AWarpTile = remove_cvref_t( + AWarpTileDistr{}))>; + using BWarpTile = remove_cvref_t( + BWarpTileDistr{}))>; // TODO: Should we have two policies? Interwave & Intrawave ?? static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -90,9 +92,10 @@ struct BlockUniversalGemmAsBsCr public: using Traits = GemmTraits_; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; using WarpGemm = remove_cvref_t; @@ -105,10 +108,34 @@ struct BlockUniversalGemmAsBsCr static constexpr auto Scheduler = Traits::Scheduler; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using I0 = number<0>; using I1 = number<1>; private: + template + CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window, + WarpTile& warp_tile) + { + constexpr index_t UnaryOpSize = 8; + const element_wise::PassThroughPack8 elementwise_op{}; + constexpr index_t thread_buffer_size = + Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } + template struct BlockGemmImpl { @@ -208,6 +235,8 @@ struct BlockUniversalGemmAsBsCr }); using CWarpDstr = typename WarpGemm::CWarpDstr; + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto c_warp_y_lengths = @@ -217,10 +246,26 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); + AWarpTensor a_warp_tile; + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile); + } + else + { + a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); + } static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); + BWarpTensor b_warp_tile; + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile); + } + else + { + b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); + } // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -342,11 +387,27 @@ struct BlockUniversalGemmAsBsCr static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block window - load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), + a_warp_tiles_(mIter)(kIter)); + } + else + { + a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); + } }); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window - load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), + b_warp_tiles_(nIter)(kIter)); + } + else + { + b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); + } }); }); } @@ -504,12 +565,27 @@ struct BlockUniversalGemmAsBsCr // TODO check if a_warp_tiles has same desc as a_warp_window static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_windows(mIter)(kIter), + a_warp_tiles_(mIter)(kIter)); + } + else + { + a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); + } }); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window - load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_windows(nIter)(kIter), + b_warp_tiles_(nIter)(kIter)); + } + else + { + b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); + } }); }); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 69c50c7cd0..73d5ce8f81 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -54,6 +54,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; @@ -196,12 +201,12 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // A/B split schedule // compiler is likely to use ds_read2 when instruction width smaller than 16bytes - constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 - ? A_LDS_Read_Inst_Num - : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 - ? B_LDS_Read_Inst_Num - : B_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; @@ -213,9 +218,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; constexpr auto ds_read_a_issue_cycle = - A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index ea8d063fd5..b679f8c8aa 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -60,6 +60,13 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; @@ -139,12 +146,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); - constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 - ? A_LDS_Read_Inst_Num - : A_LDS_Read_Inst_Num / 2; - constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 - ? B_LDS_Read_Inst_Num - : B_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b; constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index cde31f087b..b8b2d5b1c9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -21,6 +21,13 @@ struct BaseGemmPipelineAgBgCrMem using BDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } static constexpr index_t BlockSize = Problem::kBlockSize; @@ -33,9 +40,11 @@ struct BaseGemmPipelineAgBgCrMem static constexpr index_t WgpPerCU = (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; - static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil( - MinMemInFlyBytes / WgpPerCU, - (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t FullMemBandPrefetchStages = + integer_divide_ceil(MinMemInFlyBytes / WgpPerCU, + (MPerBlock * sizeof(ADataType) / APackedSize + + NPerBlock * sizeof(BDataType) / BPackedSize) * + KPerBlock); static constexpr index_t PrefetchStages = FullMemBandPrefetchStages >= 2 ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 2d9f95627c..c7115c8eb4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -67,16 +67,22 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * - MakeALdsBlockDescriptor().get_element_space_size(); + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + constexpr index_t smem_size_a = + sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size() / PackedSize; return smem_size_a; } template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * - MakeBLdsBlockDescriptor().get_element_space_size(); + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + constexpr index_t smem_size_b = + sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size() / PackedSize; return smem_size_b; } @@ -387,8 +393,8 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; using BlockGemmShape = remove_cvref_t; + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; @@ -37,13 +42,15 @@ struct GemmPipelineAGmemBGmemCRegV2 CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { - return integer_divide_ceil( - sizeof(ADataType) * - Policy::template MakeALdsBlockDescriptor().get_element_space_size(), - 16) * + return integer_divide_ceil(sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor() + .get_element_space_size() / + APackedSize, + 16) * 16 + sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + Policy::template MakeBLdsBlockDescriptor().get_element_space_size() / + BPackedSize; } template (p_a_lds, a_lds_block_desc); constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + integer_divide_ceil( + sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) * 16; // B tile in LDS diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 771662f566..f833ccc849 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -13,14 +13,16 @@ template + typename Traits_, + typename ComputeDataType_ = ADataType_> struct GemmPipelineProblemBase { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -53,13 +55,15 @@ struct GemmPipelineProblemBase CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < VectorLoadSize / sizeof(ADataType) + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType) ? pixels_per_thread - : VectorLoadSize / sizeof(ADataType); + : PackedSize * VectorLoadSize / sizeof(ADataType); } else { @@ -69,17 +73,19 @@ struct GemmPipelineProblemBase CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; - return pixels_per_thread < VectorLoadSize / sizeof(BDataType) + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType) ? pixels_per_thread - : VectorLoadSize / sizeof(BDataType); + : PackedSize * VectorLoadSize / sizeof(BDataType); } else { - return VectorLoadSize / sizeof(BDataType); + return PackedSize * VectorLoadSize / sizeof(BDataType); } } @@ -143,9 +149,14 @@ template -using GemmPipelineProblem = - GemmPipelineProblemBase; + typename Traits_, + typename ComputeDataType_ = ADataType_> +using GemmPipelineProblem = GemmPipelineProblemBase; template + TailNumber TailNum_ = TailNumber::Full, + typename ComputeDataType_ = ADataType_> struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c20d09cea4..fd1e76a02b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -34,31 +34,41 @@ struct UniversalGemmBasePolicy constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; // Assume DataType is even! - if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 && - elements_per_thread % (16 / sizeof(DataType)) == 0) + if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && + PackedSize == 2) { - return (16 / sizeof(DataType)); + return (PackedSize * 32 / sizeof(DataType)); } - else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 && - elements_per_thread % (8 / sizeof(DataType)) == 0) + else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) { - return (8 / sizeof(DataType)); + return (PackedSize * 16 / sizeof(DataType)); } - else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 && - elements_per_thread % (4 / sizeof(DataType)) == 0) + else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) { - return (4 / sizeof(DataType)); + return (PackedSize * 8 / sizeof(DataType)); } - else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 && - elements_per_thread % (2 / sizeof(DataType)) == 0) + else if constexpr(sizeof(DataType) >= PackedSize * 4 && + XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) { - return (2 / sizeof(DataType)); + return (PackedSize * 4 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 2 && + XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) + { + return (PackedSize * 2 / sizeof(DataType)); } else { - return 1; + return PackedSize; } } @@ -564,8 +574,8 @@ struct UniversalGemmPipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher, @@ -31,7 +31,85 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( std::vector, @@ -177,81 +255,239 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in #endif #endif -#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8)) -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances( +#ifdef CK_ENABLE_FP16 +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1( std::vector, Row, - I8, - I8, + F8, + F8, Tuple, - BF16, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1( std::vector, Row, - I8, - I8, + F8, + F8, Tuple, - BF16, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2( std::vector, Row, - I8, - I8, + F8, + F8, Tuple, - BF16, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2( std::vector, Row, - I8, - I8, + F8, + F8, Tuple, - BF16, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1( std::vector, Row, - I8, - I8, + F8, + F8, Tuple, - BF16, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); -void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); +#endif + +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8)) +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances( std::vector, Row, I8, I8, - Tuple, - BF16, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, PassThrough, PassThrough, MultiplyMultiply>>>& instances); @@ -261,6 +497,7 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_i template @@ -271,7 +508,7 @@ struct DeviceOperationInstanceFactory, + DsDataType, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, @@ -284,7 +521,7 @@ struct DeviceOperationInstanceFactory, + DsDataType, CDataType, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, @@ -302,9 +539,22 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1( op_ptrs); - add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2( + op_ptrs); + + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( op_ptrs); add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( @@ -326,9 +576,22 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances( + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1( op_ptrs); - add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances( + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2( + op_ptrs); + + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + op_ptrs); + add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( op_ptrs); add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances( @@ -344,26 +607,26 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances( op_ptrs); - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances( op_ptrs); - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances( op_ptrs); - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances( op_ptrs); - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances( op_ptrs); - add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances( op_ptrs); } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp new file mode 100644 index 0000000000..90da3ad0fe --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8)) +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); +#endif + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& + instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle< + ALayout, + BLayout, + Tuple, + CLayout, + ADataType, + BDataType, + Tuple, + CDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::MultiplyMultiply>> +{ + using DeviceOp = + DeviceGemmMultipleDSplitKBPreShuffle, + CLayout, + ADataType, + BDataType, + Tuple, + CDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::MultiplyMultiply>; + + static auto GetInstances() + { + std::vector> op_ptrs; +// TODO: Add MFMA layout into tensor layout +#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + op_ptrs); + + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + op_ptrs); + + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances( + op_ptrs); + } + } +#endif + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( + op_ptrs); + + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( + op_ptrs); + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( + op_ptrs); + + add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances( + op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp new file mode 100644 index 0000000000..b445e0001d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif + +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> + //clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index a08d73546d..a493719637 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -45,31 +45,49 @@ template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +template using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -78,33 +96,51 @@ template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +template using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -113,31 +149,49 @@ template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + +template using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -145,34 +199,36 @@ template + ConvolutionBackwardWeightSpecialization ConvSpec, + index_t MaxTransposeTransferSrcScalarPerVector = 1, + index_t MaxTransposeTransferDstScalarPerVector = 1> using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for bf16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -180,35 +236,37 @@ template + ConvolutionBackwardWeightSpecialization ConvSpec, + index_t MaxTransposeTransferSrcScalarPerVector = 1, + index_t MaxTransposeTransferDstScalarPerVector = 1> using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> #endif // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index 410abe366c..f9b3b43795 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -61,8 +61,9 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#else // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -81,7 +82,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> -#endif // defined(__gfx950__) +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -97,8 +98,9 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#else // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -113,7 +115,7 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif // defined(__gfx950__) +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -148,8 +150,9 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#else // defined(CK_USE_AMD_MFMA_GFX950) DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -160,7 +163,7 @@ using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> -#endif // defined(__gfx950__) +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 09a489cd04..9114d5c1fb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -45,13 +45,16 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else +#if defined(CK_USE_AMD_MFMA_GFX950) + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> +#else // defined(CK_USE_AMD_MFMA_GFX950) // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> -#endif // defined(__gfx950__) +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; @@ -67,13 +70,17 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) -#else +#if defined(CK_USE_AMD_MFMA_GFX950) + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> +#else // defined(CK_USE_AMD_MFMA_GFX950) // Instances with NumGroupsPerBatch > 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> -#endif // defined(__gfx950__) +#endif // defined(CK_USE_AMD_MFMA_GFX950) // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index f4cc7da5e1..f1993eb149 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -311,6 +311,11 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( + op_ptrs); } #endif } @@ -471,6 +523,15 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( + op_ptrs); } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index ca28a4b516..31a536c7bb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -74,6 +74,30 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( @@ -87,6 +111,30 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -101,6 +149,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instances( + std::vector>>& instances); void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( + std::vector>>& instances); #endif // conv3d backward weight #ifdef CK_ENABLE_BF16 @@ -384,6 +611,66 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt index 3b63c2c160..57bae7a2ac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt @@ -2,35 +2,59 @@ set(GEMM_MULTIPLY_MULTIPLY_INSTANCES) list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES - device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp - device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp - device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance.cpp - device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp + device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp - device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp ) -set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp index 6d8d93ca79..17da9ce5b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -35,7 +35,7 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1 = std::tuple< // clang-format off //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| @@ -44,15 +44,13 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std #if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) // Compute friendly DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 32, 32, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 32, 32, 2, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 32, 32, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 32, 32, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, @@ -61,7 +59,28 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 3, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 96, 128, 128, 16, 16, 16, 16, 3, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 96, 128, 256, 16, 16, 16, 16, 3, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 80, 128, 256, 16, 16, 16, 16, 5, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 192, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, @@ -80,28 +99,85 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; +template +using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1 = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + // 256x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // 224x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 192x[64, 256, 32]x128, 192x[64]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 256, 16, 16, 16, 16, 6, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> #endif - // clang-format on - >; + // clang-format on + >; + +template +using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2 = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + + // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 256, 16, 16, 16, 16, 5, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 256, 16, 16, 16, 16, 5, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 128x[64, 256, 32]x128, 128x[64, 128, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; template using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< @@ -136,16 +212,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std: DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + + // v3 more DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + +// clang-format on #endif - // clang-format on >; } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp index 6527d93473..f0e0d3d689 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1( std::vector, @@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst { add_device_operation_instances( instances, - device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances{}); + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp index 7f16a7a2c5..9fec958841 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2( std::vector, @@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins { add_device_operation_instances( instances, - device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances{}); + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp new file mode 100644 index 0000000000..686dcb67e1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp new file mode 100644 index 0000000000..e9d1183199 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp new file mode 100644 index 0000000000..687b95f07a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp new file mode 100644 index 0000000000..3f2ffbf231 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp new file mode 100644 index 0000000000..9b6ec0fdba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1< + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp new file mode 100644 index 0000000000..e103ecfce1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2< + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp index 1730eba3cc..5c854ee5d9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -35,7 +35,7 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part1 = std::tuple< // clang-format off //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| @@ -44,9 +44,6 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances = std: #if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) // Compute friendly DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 32, 32, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, @@ -61,6 +58,20 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances = std: DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, @@ -80,28 +91,85 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances = std: DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; +template +using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part1 = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + // 256x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // 224x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 192x[64, 256, 32]x128, 192x[64]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 256, 16, 16, 16, 16, 6, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> #endif - // clang-format on - >; + // clang-format on + >; + +template +using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part2 = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + + // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 256, 16, 16, 16, 16, 5, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 256, 16, 16, 16, 16, 5, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 128x[64, 256, 32]x128, 128x[64, 128, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; template using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances = std::tuple< @@ -136,16 +204,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances = std:: DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + + // v3 more DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + +// clang-format on #endif - // clang-format on >; } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp new file mode 100644 index 0000000000..463bbd1570 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp new file mode 100644 index 0000000000..a78ba4692b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp new file mode 100644 index 0000000000..f937a607d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part1{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp new file mode 100644 index 0000000000..fa2b5ba6e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part2{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp new file mode 100644 index 0000000000..38b19178dc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp new file mode 100644 index 0000000000..95fda4fbfb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp new file mode 100644 index 0000000000..4a86e2e8d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part1< + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp new file mode 100644 index 0000000000..19c53e6287 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part2< + GemmKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..04d44ad6b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using I32 = int; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using MultiplyMultiply = element_wise::MultiplyMultiply; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 32, 32, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 32, 32, 2, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 32, 32, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 32, 32, 1, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 32, 32, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 224, 128, 16, 16, 16, 16, 2, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 192, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 192, 128, 16, 16, 32, 32, 1, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 160, 256, 16, 16, 16, 16, 2, 5, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 96, 256, 16, 16, 16, 16, 2, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 224, 256, 16, 16, 16, 16, 1, 7, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 192, 256, 16, 16, 16, 16, 1, 6, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 160, 256, 16, 16, 16, 16, 1, 5, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 96, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, I8> + + // clang-format oI + >; + +template +using device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, I8>, + // Memory friendly + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8>, + DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple, Row, I8, I8, Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, I8> + // clang-format oI + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..1162b7a110 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..ac2994e868 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..5fb2d06543 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..e948057284 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..82cc25f4eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..8c3c2d2c58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector, + Row, + I8, + I8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..943b2bf4c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/CMakeLists.txt @@ -0,0 +1,42 @@ +# ONLY XDL_KERNELS +set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES) + +list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp + + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp + ) + +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp new file mode 100644 index 0000000000..71383f5dc1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp new file mode 100644 index 0000000000..a138452295 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using MultiplyMultiply = element_wise::MultiplyMultiply; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; +static constexpr auto v2 = BlockGemmPipelineVersion::v2; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 512 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 512, 16, 16, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 256, 16, 16, 32, 32, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 512 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances = + std::tuple< +// clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances = + std::tuple< +// clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + // 256x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 224x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 192x[64, 256, 32]x128, 192x[64]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 128x[64, 256, 32]x128, 128x[64, 128, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp new file mode 100644 index 0000000000..e3ff079d99 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp new file mode 100644 index 0000000000..6e9b3ea172 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp new file mode 100644 index 0000000000..cc543b19c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp new file mode 100644 index 0000000000..8557d0c80e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp new file mode 100644 index 0000000000..9fcce478e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp new file mode 100644 index 0000000000..84c2c70e35 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp new file mode 100644 index 0000000000..0933b1fe18 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp new file mode 100644 index 0000000000..05529f9cdd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp new file mode 100644 index 0000000000..c4f53e834a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using MultiplyMultiply = element_wise::MultiplyMultiply; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; +static constexpr auto v2 = BlockGemmPipelineVersion::v2; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 512 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 512, 16, 16, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 256, 16, 16, 32, 32, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // N 512 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances = + std::tuple< +// clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances = + std::tuple< +// clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances = + std::tuple< +// clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + // 256x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 224x[64, 256, 32]x128 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 192x[64, 256, 32]x128, 192x[64]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // 128x[64, 256, 32]x128, 128x[64, 128, 32]x256 + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp new file mode 100644 index 0000000000..c123a0fdd8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp new file mode 100644 index 0000000000..cb15688e6f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp new file mode 100644 index 0000000000..c5a8448b59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp new file mode 100644 index 0000000000..c9ab9c1071 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp new file mode 100644 index 0000000000..bb83bacb35 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp new file mode 100644 index 0000000000..fb43347ceb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances< + v1, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp new file mode 100644 index 0000000000..c8ff03d6ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + std::vector, + Row, + F8, + F8, + Tuple, + F16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances< + v2, + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp index a97953de35..d9f9969621 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -44,7 +44,9 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, @@ -54,7 +56,9 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -86,7 +90,9 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 8, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 8, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 8, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 8, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 8, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index e9a6fe313d..0619a98cf0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -43,13 +43,17 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> @@ -66,11 +70,10 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tup #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -80,8 +83,10 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 918ef57a11..3edbd28cd8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -44,16 +44,21 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tu #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 64, 8, 8, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> @@ -70,11 +75,12 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tup #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 4, 4, 16, 16, 4, 1, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -84,8 +90,10 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_instances = std::tup DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index bffa5db2d4..e3f3afff46 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -47,11 +47,12 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> @@ -68,15 +69,10 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 4, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 2, 2, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 2, 2, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp index 6c21aeb573..e39c9a63b9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp @@ -57,11 +57,13 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 8, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) // AGPR Spill // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -80,14 +82,11 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< #if defined(CK_USE_AMD_MFMA_GFX950) #endif // defined(CK_USE_AMD_MFMA_GFX950) // Latency friendly - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 8, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 128, 8, 8, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 4, 4, 32, 32, 2, 1, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, @@ -105,9 +104,11 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp index 2650202a64..43dc6be076 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -47,7 +47,9 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< // Disable due to test failure // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp index 5e278de812..9bdb2f51c2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp @@ -44,7 +44,9 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< #endif // defined(CK_USE_AMD_MFMA_GFX950) // Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp index 8666cf8589..616133e1ba 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -47,9 +47,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 8, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp index f5e801c167..6388c13444 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -47,8 +47,11 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 128, 16, 16, 32, 32, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 21cef335c5..7fb690c8b2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -50,7 +50,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index e36b7f3093..68c6ce89ab 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index ef1808c551..928c325ab7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp index 763ac4facf..2a02995827 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp @@ -48,7 +48,9 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 7a59823d9a..425c2c0391 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -51,11 +51,18 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // AGPR Spill // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on >; @@ -85,6 +92,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp index 9e22c8f992..a685c4f252 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -58,7 +58,9 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 8, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#if !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, +#endif // !defined(CK_USE_AMD_MFMA_GFX950) DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 77a2097817..ac393560b7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -7,6 +7,25 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp index 69e22dee41..abff222a4e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances< + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_generic_instances< 2, GNHWC, GKYXC, @@ -34,7 +34,7 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances< + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_generic_instances< 2, GNHWC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..5282d0da4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index abb90b7c71..57df822021 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -25,19 +25,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_generic_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..715789747d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp new file mode 100644 index 0000000000..b93a4d0d97 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 2e0fef9cf8..c899f34d51 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -25,19 +25,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..026ab88158 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..2c494f22ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp new file mode 100644 index 0000000000..911751b0b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp new file mode 100644 index 0000000000..3b1fdd9be7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp new file mode 100644 index 0000000000..9e717ea9c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp new file mode 100644 index 0000000000..2081bc545f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..42927a91e8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instances( + std::vector>>& instances) +{ + + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..15366961f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instances( + std::vector>>& instances) +{ + + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp new file mode 100644 index 0000000000..b38779de9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); + ; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp new file mode 100644 index 0000000000..18a6d326b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); + ; +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..485eb85199 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..fae04b5718 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp new file mode 100644 index 0000000000..ba82f403fb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp new file mode 100644 index 0000000000..4e2d37576d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..f99c644618 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..3be56da25b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 3d86949f79..860e08cafe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -7,22 +7,37 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp ) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp index 81d64344f7..752f8a8157 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -24,7 +24,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances< + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_generic_instances< 3, GNDHWC, GKZYXC, @@ -33,7 +33,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16 // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances< + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_generic_instances< 3, GNDHWC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index d03f0a7bac..4137fcb536 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -24,19 +24,20 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_generic_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< - 3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index 7c24cc8fd4..d838a4e821 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" @@ -24,19 +24,20 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< - 3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp new file mode 100644 index 0000000000..553d3bfc36 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp new file mode 100644 index 0000000000..3cadf1d99c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..71e35c86a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..78a87d060e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp new file mode 100644 index 0000000000..18d6871201 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp new file mode 100644 index 0000000000..14d01f0702 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..ab0f977258 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..bdc2248ed0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp new file mode 100644 index 0000000000..ea3b9264b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp new file mode 100644 index 0000000000..1bcde328b0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp new file mode 100644 index 0000000000..6c4c53e5f3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp new file mode 100644 index 0000000000..97cc06cbad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..fd558f70ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..945e6fa563 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp new file mode 100644 index 0000000000..b7e6454062 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp index 29a645e9d7..dbfddeb8a4 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp @@ -188,7 +188,8 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, // profile device GEMM instances for(auto& op_ptr : op_ptrs) { - std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; + // Seems like when performance measurement has bug when spiltK is large + std::vector kbatch_list = {1, 2, 4, 8, 16}; if(KBatch > 0) { @@ -232,7 +233,26 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, { c_device_buf.FromDevice(e_m_n_device_result.mData.data()); - pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 + // set softer tolerances for fp8 + if constexpr((is_same_v || is_same_v || + is_same_v) || + (is_same_v || is_same_v || + is_same_v)) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 + } +#endif if(do_log) { @@ -249,6 +269,10 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, std::string op_name = op_ptr->GetTypeString(); + // timer of develop branch should only apply to empty hipstream + // hipStream_t stream; + // hip_check_error(hipStreamCreate(&stream)); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, @@ -271,27 +295,6 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; -#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 - // set softer tolerances for fp8 - if constexpr((is_same_v || is_same_v || - is_same_v) || - (is_same_v || is_same_v || - is_same_v)) - { - std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; - pass = pass & ck::utils::check_err( - e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); - } - else - { -#endif - pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); -#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 - } -#endif - if(tflops > best_tflops && ave_time > 1e-10) { best_op_name = op_name; diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp new file mode 100644 index 0000000000..177e652cc3 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp @@ -0,0 +1,396 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} + +template +bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE, + int KBatch, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_preshuffled_mfma16( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size + Tensor b_preshuffled_mfma32( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + int total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + + d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma16.mData.data(), N, K, 16); + preShuffleBuffer(b_k_n.mData.data(), b_preshuffled_mfma32.mData.data(), N, K, 32); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = MultiplyMultiply; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_mfma16(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_mfma32(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf_mfma16.ToDevice(b_preshuffled_mfma16.mData.data()); + b_device_buf_mfma32.ToDevice(b_preshuffled_mfma32.mData.data()); + + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + c_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + int NPerXdl = op_ptr->GetPreShuffleParameters(); + + std::vector kbatch_list = {1, 2, 4, 8}; + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(NPerXdl == 16 ? b_device_buf_mfma16.GetDeviceBuffer() + : b_device_buf_mfma32.GetDeviceBuffer()), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + + if(do_verification) + { + c_device_buf.FromDevice(e_m_n_device_result.mData.data()); + +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 + // set softer tolerances for fp8 + if constexpr((is_same_v || is_same_v || + is_same_v) || + (is_same_v || is_same_v || + is_same_v)) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-3; + double atol = 5e-2; + pass = pass & ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8 + } +#endif + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_host : ", e_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", e_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops && ave_time > 1e-10) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideE = " << StrideE << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index d2d7753fb4..5ed28b9826 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -50,6 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp) list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) endif() list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) @@ -139,6 +140,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index 24c8630594..ad2bb77544 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -28,7 +28,8 @@ enum struct GemmDataType F16_F16_F16_F8, // 6 F8_F8_BF16, // 7 INT8_INT8_BF16, // 8 - F8_F8_F16, // 9 + INT8_INT8_F16, // 9 + F8_F8_F16, // 10 }; #define OP_NAME "gemm_multiply_multiply" @@ -177,6 +178,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) return profile( I8{}, I8{}, I8{}, I32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); } + else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile( + I8{}, I8{}, I8{}, I32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_multiply_multiply_weight_preshuffle.cpp b/profiler/src/profile_gemm_multiply_multiply_weight_preshuffle.cpp new file mode 100644 index 0000000000..ee3be398e5 --- /dev/null +++ b/profiler/src/profile_gemm_multiply_multiply_weight_preshuffle.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_MFMA_MN, // 0 +}; + +enum struct GemmDataType +{ + F8_F8_F16, // 0 + F8_F8_BF16, // 1 +}; + +#define OP_NAME "gemm_multiply_multiply_weight_preshuffle" +#define OP_DESC "GEMM_Multiply_Multiply_Weight_PreShuffle" + +int profile_gemm_multiply_multiply_weight_preshuffle(int argc, char* argv[]) +{ + if(argc != 16 && argc != 20) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f8->f16; 1: f8->bf16;\n"); + printf("arg3: matrix layout (0: A[m, k] * B[MFMA] = C[m, n];\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + printf("optional:\n"); + printf("arg16: number of kbatch (default 1)\n"); + printf("arg17: number of warm-up cycles (default 1)\n"); + printf("arg18: number of iterations (default 10)\n"); + printf("arg19: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideE = std::stoi(argv[15]); + + int n_warmup = 1; + int n_iter = 10; + uint64_t rotating = 0; + int KBatch = 1; + if(argc == 20) + { + KBatch = std::stoi(argv[16]); + n_warmup = std::stoi(argv[17]); + n_iter = std::stoi(argv[18]); + rotating = std::stoull(argv[19]) * 1024 * 1024; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using F8 = ck::f8_t; + // using I8 = int8_t; + // using I32 = int; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto comp_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto c_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto d1_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using ComputeDataType = decltype(comp_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using AccDataType = decltype(acc_type); + using EDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using D1Layout = decltype(d1_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = + ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideE < 0) ? DefaultStrideE : StrideE, + KBatch, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_MFMA_MN) + { + return profile( + F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_MFMA_MN) + { + return profile( + F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multiply_multiply_weight_preshuffle); diff --git a/profiler/src/profile_gemm_universal_streamk.cpp b/profiler/src/profile_gemm_universal_streamk.cpp index a94bb866f2..b0f66a0c73 100644 --- a/profiler/src/profile_gemm_universal_streamk.cpp +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -56,6 +56,26 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) exit(1); } + int M; + int N; + int StrideA; + int StrideB; + // Analyze the unsupported matrix shapes, switch the M and N number + if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0) + { + M = std::stoi(argv[9]); + StrideA = std::stoi(argv[12]); + N = std::stoi(argv[8]); + StrideB = std::stoi(argv[11]); + } + else + { + M = std::stoi(argv[8]); + StrideA = std::stoi(argv[11]); + N = std::stoi(argv[9]); + StrideB = std::stoi(argv[12]); + } + const auto data_type = static_cast(std::stoi(argv[2])); const auto layout = static_cast(std::stoi(argv[3])); const bool do_verification = std::stoi(argv[4]); @@ -63,12 +83,8 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) const bool do_log = std::stoi(argv[6]); const bool time_kernel = std::stoi(argv[7]); - const int M = std::stoi(argv[8]); - const int N = std::stoi(argv[9]); const int K = std::stoi(argv[10]); - const int StrideA = std::stoi(argv[11]); - const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); const int Streamk_sel = std::stoi(argv[14]); const int Grid_size = std::stoi(argv[15]); diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 6089fc7a7e..cdf407d6cd 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 313b5ba4ca..063e0248e7 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -9,6 +9,10 @@ add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_c if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility) endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_v3_interface_xdl test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_weight_v3_interface_xdl PRIVATE utility) +endif() add_gtest_executable(test_grouped_convnd_bwd_weight_interface_wmma test_grouped_convnd_bwd_weight_interface_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface_wmma PRIVATE utility) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp new file mode 100644 index 0000000000..1556f15898 --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_v3_interface_xdl.cpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using F16 = ck::half_t; +using F32 = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvolutionBackwardWeightSpecialization = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + +static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using InLayout = std::tuple_element_t<2, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using OutLayout = std::tuple_element_t<0, Tuple>; + + // clang-format off + using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 + //##########| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //##########| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //##########| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + < NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 32, 32, 1, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 2, 8, 1, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + ck::index_t split_k{2}; + + template + bool Run() + { + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; + std::array input_strides{}; + std::array weights_strides{}; + std::array output_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; + + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); + range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); + range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); + range_copy(conv_param.input_left_pads_, begin(input_left_pads)); + range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + + auto conv = GroupedConvBwdWeightDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + return conv.IsSupportedArgument(argument); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = + ::testing::Types, std::tuple>; + +template +class TestGroupedConvndBwdWeightV3Default + : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeightV3Filter1x1 + : public TestGroupedConvndBwdWeight +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightV3Default, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightV3Filter1x1, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdWeightV3Filter1x1, SpecializationCheck) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Not supported version + this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdWeightV3Default, VectorLoadCheck) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index cc612794f4..f65e89bb82 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mx_mfma_kernel = ck::matmul; + const auto mfma_kernel = ck::matmul; bool pass = true; - pass = ck::mfma_test::TestMFMA{}(mx_mfma_kernel, init); + BLOCK_K>{}(mfma_kernel, init); return pass; } TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } + +/** + * @brief Run the test for the given MX MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mxmfma_test(ck::index_t init) +{ + static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || + mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, + "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported"); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using ScaleType = ck::e8m0_bexp_t; // biased exponent type + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + constexpr auto BLOCK_X = 32; // scaling vector size + + const auto mx_mfma_kernel = + ck::matmul; + + bool pass = true; + + pass = ck::mxmfma_test::TestMXMFMA{}(mx_mfma_kernel, init); + + return pass; +} + +TEST(MXMFMA, MXFP8MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP8MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index e96e1b0b29..1f9091ebc5 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck/ck.hpp" @@ -7,7 +10,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" namespace ck { @@ -18,7 +21,13 @@ enum class MFMA_F8F6F4 F32_16x16x128 = static_cast(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4 F32_32x32x64 = - static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4 + static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4 + + SCALE_F32_16x16x128 = static_cast( + MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4 + SCALE_F32_32x32x64 = static_cast( + MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4 + }; template @@ -32,6 +41,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -42,6 +62,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<32, 32, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -52,151 +83,428 @@ static constexpr int32_t vectorSize(const VecT&) // Define a load function for input A blocks: // Size: (BLOCK_M x BLOCK_K) -// ASSUMPTION: -// - We want contiguous BLOCK_M sized column neighbors in register. -// - Data is in col_major format -// This means: -// - From A we will load K columns of size BLOCK_M to satisfy our input data +// - Data is in column major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off // Register Mapping for 16x128: || Register Mapping for 32x64: - // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | - // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_M x BLOCK_K block of data. - static constexpr uint32_t VW = vectorSize(AFragT{}); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW); // Col - auto stepCoord2D = std::make_pair(0u, 1u); + auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; // BLOCK_M is a stride in A matrix - auto startOffset = col_major(startCoord2D, BLOCK_M); - auto kOffset = col_major(stepCoord2D, BLOCK_M); + auto startOffset = col_major(startCoord2D, BLOCK_M); + auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - // kOffset == BLOCK_M - // This means every BLOCK_M element is loaded into output vector - auto fragA = AScalarFragT{}; -#pragma unroll VW - for(uint32_t i = 0; i < VW; i++) + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + AScalarFragT fragA{}; + +#pragma unroll + for(int chunk = 0; chunk < 2; chunk++) { - fragA[i] = bit_cast(input_ptr[startOffset + i * kOffset]); +#pragma unroll + for(uint32_t i = 0; i < chunk_size; i++) + { + fragA[chunk * chunk_size + i] = + bit_cast(input_ptr[startOffset + chunk * kMajorOffset + i * kMinorOffset]); + } } return fragA; } +// Define a load function for input A blocks: +// Size: (BLOCK_M x BLOCK_K) +// - Data is in row major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements +template +__device__ AFragT load_A_row_major(AType const* input_ptr) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + // clang-format on + + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 32 elements. + // We need to know where they start, and where the next elements are. + auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K is a stride in A matrix + auto startOffset = row_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + union + { + AFragT frag; + AScalarFragT chunks[2]; + } fragA{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragA.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragA.chunks[1] = *fragPtr; + + return fragA.frag; +} + +// Define a load function for scaled A blocks: +// Size: (BLOCK_M x BLOCK_K) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ AFragT load_mx_A_row_major(AType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | | + // M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) | + // Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) | + // Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) | + // Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) | + // Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) | + // Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) | + // Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) | + // Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) | + // Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) | + // Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) | + // Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) | + // Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) | + // Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) | + // Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) | + // Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) | + // Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(M,0) | + // Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) | + // Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) | + // Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) | + // Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) | + // Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) | + // Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) | + // Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) | + // Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) | + // Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) | + // Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) | + // Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) | + // Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) | + // Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) | + // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | + // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | + // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + // clang-format on + static constexpr uint32_t VW = vectorSize(AFragT{}); + static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 1 element + // We need to know where they start + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K / BLOCK_X is a stride in xA matrix + auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X); + + // obtain 8-bit exponent + fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + + return load_A_row_major(input_ptr); +} + // Define a load function for input B blocks: // Size: (BLOCK_K x BLOCK_N) -// ASSUMPTION: -// - We want contiguous BLOCK_N sized row neighbors in register. -// - Data is in row_major format -// This means: -// - From B we will load K rows of size BLOCK_N to satisfy our input data +// - Data is in col major format +// - Cols are loaded in contiguous chunks that map to corresponding microscales +// - Each col is loaded in chunks of size 16 and each thread loads 32 elements template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off // Register Mapping for 128x16: || Register Mapping for 64x32: - // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | - // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_K x BLOCK_N block of data. - static constexpr uint32_t VW = vectorSize(BFragT{}); + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from cols of B in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by an offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64 // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = + std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48} + threadIdx.x % BLOCK_N); // Col {0-31} | {0-15} // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; - auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols + auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - auto const* fragPtr = reinterpret_cast(input_ptr + startOffset); - return *fragPtr; + // BLOCK_K is a stride in B matrix + auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + + using BRawT = typename scalar_type::type; + using BScalarFragT = vector_type::type; + + union + { + BFragT frag; + BScalarFragT chunks[2]; + } fragB{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragB.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragB.chunks[1] = *fragPtr; + + return fragB.frag; +} + +// Define a load function for scaled B blocks: +// Size: (BLOCK_K x BLOCK_N) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ BFragT load_mx_B_col_major(BType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) + +{ + // clang-format off + // Register Mapping for 128x16: || Register Mapping for 64x32: + // Size | BLOCK_N | BLOCK_N | | BLOCK_N | BLOCK_N | | || Size | BLOCK_N | BLOCK_N | | | + // N | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || N | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(0,N) | K32 | K48 | x(1,N) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(0,N) | + // Reg 0 [8:15] | K1 | K17 | x(0,N) | K33 | K49 | x(1,N) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(0,N) | + // Reg 0 [16:23] | K2 | K18 | x(0,N) | K34 | K50 | x(1,N) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(0,N) | + // Reg 0 [24:31] | K3 | K19 | x(0,N) | K35 | K51 | x(1,N) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(0,N) | + // Reg 1 [0:7] | K4 | K20 | x(0,N) | K36 | K52 | x(1,N) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(0,N) | + // Reg 1 [8:15] | K5 | K21 | x(0,N) | K37 | K53 | x(1,N) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(0,N) | + // Reg 1 [16:23] | K6 | K22 | x(0,N) | K38 | K54 | x(1,N) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(0,N) | + // Reg 1 [24:31] | K7 | K23 | x(0,N) | K39 | K55 | x(1,N) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(0,N) | + // Reg 2 [0:7] | K8 | K24 | x(0,N) | K40 | K56 | x(1,N) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(0,N) | + // Reg 2 [8:15] | K9 | K25 | x(0,N) | K41 | K57 | x(1,N) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(0,N) | + // Reg 2 [16:23] | K10 | K26 | x(0,N) | K42 | K58 | x(1,N) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(0,N) | + // Reg 2 [24:31] | K11 | K27 | x(0,N) | K43 | K59 | x(1,N) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(0,N) | + // Reg 3 [0:7] | K12 | K28 | x(0,N) | K44 | K60 | x(1,N) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(0,N) | + // Reg 3 [8:15] | K13 | K29 | x(0,N) | K45 | K61 | x(1,N) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(0,N) | + // Reg 3 [16:23] | K14 | K30 | x(0,N) | K46 | K62 | x(1,N) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(0,N) | + // Reg 3 [24:31] | K15 | K31 | x(0,N) | K47 | K63 | x(1,N) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(0,N) | + // Reg 4 [0:7] | K64 | K80 | x(2,N) | K96 | K112 | x(3,N) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(1,N) | + // Reg 4 [8:15] | K65 | K81 | x(2,N) | K97 | K113 | x(3,N) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(1,N) | + // Reg 4 [16:23] | K66 | K82 | x(2,N) | K98 | K114 | x(3,N) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(1,N) | + // Reg 4 [24:31] | K67 | K83 | x(2,N) | K99 | K115 | x(3,N) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(1,N) | + // Reg 5 [0:7] | K68 | K84 | x(2,N) | K100 | K116 | x(3,N) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(1,N) | + // Reg 5 [8:15] | K69 | K85 | x(2,N) | K101 | K117 | x(3,N) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(1,N) | + // Reg 5 [16:23] | K70 | K86 | x(2,N) | K102 | K118 | x(3,N) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(1,N) | + // Reg 5 [24:31] | K71 | K87 | x(2,N) | K103 | K119 | x(3,N) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(1,N) | + // Reg 6 [0:7] | K72 | K88 | x(2,N) | K104 | K120 | x(3,N) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(1,N) | + // Reg 6 [8:15] | K73 | K89 | x(2,N) | K105 | K121 | x(3,N) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(1,N) | + // Reg 6 [16:23] | K74 | K90 | x(2,N) | K106 | K122 | x(3,N) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(1,N) | + // Reg 6 [24:31] | K75 | K91 | x(2,N) | K107 | K123 | x(3,N) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(1,N) | + // Reg 7 [0:7] | K76 | K92 | x(2,N) | K108 | K124 | x(3,N) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(1,N) | + // Reg 7 [8:15] | K77 | K93 | x(2,N) | K109 | K125 | x(3,N) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(1,N) | + // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) | + // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | + + // clang-format on + static constexpr uint32_t VW = vectorSize(BFragT{}); + static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 1 element + // We need to know where to start + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row + threadIdx.x % BLOCK_N); // Col + + // Flatten to 1D col_major offsets. + auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; + + auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X); + + // obtain 8-bit exponent + fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + + return load_B_col_major(input_ptr); } // Define a store function for C @@ -309,6 +617,129 @@ struct store_C_col_major } }; +// Define a store function for C +// Size: (BLOCK_M x BLOCK_N) +// ASSUMPTION: +// - We want contiguous BLOCK_N sized row neighbors in register. +// - Data is in row major format +template +struct store_C_row_major; + +// Here we want to store a 16x16 block of data. +// +// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | +// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | +// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector +// Register Element ------------ ------------- ------------ -------------- Element +// Reg0 | M0 | M4 | M8 | M12 | v[0] +// Reg1 | M1 | M5 | M9 | M13 | v[1] +// Reg2 | M2 | M6 | M10 | M14 | v[2] +// Reg3 | M3 | M7 | M11 | M15 | v[3] +template +struct store_C_row_major +{ + __device__ void operator()(CType* output, CFragT cFrag) + { + static constexpr uint32_t VW = vectorSize(cFrag); // 4 + static constexpr uint32_t Dim = 16; + + // Each thread will load 4 elements. + // We need to know where they start, and where the next elements are. + auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row + threadIdx.x % Dim); // Col + auto stepCoord2D = std::make_pair(1u, 0u); + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + auto startOffset = row_major(startCoord2D, 16); + auto kOffset = row_major(stepCoord2D, 16); + + auto* fragPtr = reinterpret_cast(output + startOffset); + *fragPtr = cFrag; + + // If you notice carefully, kOffset != 1. + // This means the following is vector is updated with 4 non-contiguous offsets, + // which the compiler will separate into 4 different global_store_dword instructions. + output[startOffset] = cFrag[0]; // v[0] = Reg 0 + output[startOffset + kOffset] = cFrag[1]; // v[1] = Reg 1 + output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2 + output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3 + } +}; + +// Here we want to store a 32x32 block of data. +// Register Mapping: + +// Size | BLOCK_N | BLOCK_N | +// N | 0 ... 31 | 0 ... 31 | +// Thread Id | 0 ... 31 | 32 ... 63 | Vector +// Register Element ------------ ------------- Element +// Reg0 | M0 | M4 | v[0] +// Reg1 | M1 | M5 | v[1] +// Reg2 | M2 | M6 | v[2] +// Reg3 | M3 | M7 | v[3] +// ____________ _____________ +// Reg4 | M8 | M12 | v[4] +// Reg5 | M9 | M13 | v[5] +// Reg6 | M10 | M14 | v[6] +// Reg7 | M11 | M15 | v[7] +// ____________ _____________ +// Reg8 | M16 | M20 | v[8] +// Reg9 | M17 | M21 | v[9] +// Reg10 | M18 | M22 | v[10] +// Reg11 | M19 | M23 | v[11] +// ____________ _____________ +// Reg12 | M24 | M28 | v[12] +// Reg13 | M25 | M29 | v[13] +// Reg14 | M26 | M30 | v[14] +// Reg15 | M27 | M31 | v[15] + +template +struct store_C_row_major +{ + __device__ void operator()(CType* output, CFragT cFrag) + { + static constexpr uint32_t WAVE_SIZE = 64; + static constexpr uint32_t VW = 4; // This VW is per 'chunk' + static constexpr uint32_t Dim = 32; // BLOCK_N + static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8 + + auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row + threadIdx.x % Dim); // Col + + // Minor step for each 'chunk' + auto minorStepCoord2D = std::make_pair(1u, 0u); + + // Major step between 'chunks' + auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0); + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + auto startOffset = row_major(startCoord2D, 32); + auto kMinorOffset = row_major(minorStepCoord2D, 32); + auto kMajorOffset = row_major(majorStepCoord2D, 32); + + output[startOffset] = cFrag[0]; // v[0] = Reg 0 + output[startOffset + kMinorOffset] = cFrag[1]; // v[1] = Reg 1 + output[startOffset + 2 * kMinorOffset] = cFrag[2]; // v[2] = Reg 2 + output[startOffset + 3 * kMinorOffset] = cFrag[3]; // v[3] = Reg 3 + output[startOffset + kMajorOffset] = cFrag[4]; // v[4] = Reg 4 + output[startOffset + kMajorOffset + kMinorOffset] = cFrag[5]; // v[5] = Reg 5 + output[startOffset + kMajorOffset + 2 * kMinorOffset] = cFrag[6]; // v[6] = Reg 6 + output[startOffset + kMajorOffset + 3 * kMinorOffset] = cFrag[7]; // v[7] = Reg 7 + output[startOffset + 2 * kMajorOffset] = cFrag[8]; // v[8] = Reg 8 + output[startOffset + 2 * kMajorOffset + kMinorOffset] = cFrag[9]; // v[9] = Reg 9 + output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset] = cFrag[10]; // v[10] = Reg 10 + output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset] = cFrag[11]; // v[11] = Reg 11 + output[startOffset + 3 * kMajorOffset] = cFrag[12]; // v[12] = Reg 12 + output[startOffset + 3 * kMajorOffset + kMinorOffset] = cFrag[13]; // v[13] = Reg 13 + output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset] = cFrag[14]; // v[14] = Reg 14 + output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset] = cFrag[15]; // v[15] = Reg 15 + } +}; + template {}; storeC(c, fragC); } + +template +__global__ void +matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) +{ + constexpr int WAVE_SIZE = 64; + assert(threadIdx.x < WAVE_SIZE); + assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); + + using AFragT = vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; + using AccumFragT = vector_type; + using RawAccumFragT = vector_type::type; + using ScaleFragT = int32_t; + + // Create frags + auto fragA = AFragT{}; + auto fragB = BFragT{}; + auto fragC = CFragT{}; + auto fragAcc = AccumFragT{0}; + auto fragXa = ScaleFragT{0}; + auto fragXb = ScaleFragT{0}; + + // Load the inputs. + // A = col major, BLOCK_M x BLOCK_K + fragA = load_mx_A_row_major( + a, xa, fragXa); + + // B = col major, BLOCK_K x BLOCK_N + fragB = load_mx_B_col_major( + b, xb, fragXb); + + // Scaled Matrix multiply-accumulate using MFMA units + // Accumulation intermediate = BLOCK_M x BLOCK_N + mfma_type_selector{}( + fragA, fragXa, fragB, fragXb, fragAcc); + + for(int i = 0; i < vectorSize(fragC); ++i) + { + fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); + } + + auto storeC = store_C_row_major{}; + storeC(c, fragC); +} + /** * @brief Structure to hold dimension parameters for GEMM tensors. * @@ -373,6 +859,225 @@ struct GemmParams ck::index_t StrideC = -1; }; +namespace mxmfma_test { +template +void RunHostGEMM(const Tensor& A, + const Tensor& a_scales, + const Tensor& B, + const Tensor& b_scales, + Tensor& C) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + A, a_scales, B, b_scales, C, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& a_scales, + const Tensor& B, + const Tensor& b_scales, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + a_scales_device_buf.ToDevice(a_scales.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + b_scales_device_buf.ToDevice(b_scales.mData.data()); + kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(a_scales_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(b_scales_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestMXMFMA +{ + auto PrepareGemmTensors(const GemmParams& params, index_t init) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor a_scales( + f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor b_scales( + f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + switch(init) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue( + GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 + // NOTE: not all numbers are representable in FP8, BF8, etc. + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); + break; + case 1: + // results in C = {K} + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}}); + break; + case 2: + // expect small round off errors + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + + b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); + break; + + case 3: + // expect small round off errors + a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + b_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + break; + default: + // all initial values are representable in FP8, BF8 + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + a_scales.GenerateTensorValue( + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_scales.GenerateTensorValue( + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + + break; + } + + return std::make_tuple( + a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceMFMA& mfma_kernel, index_t init) + { + // Arrange + GemmParams params; + params.M = BLOCK_M; + params.N = BLOCK_N; + params.K = BLOCK_K; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + ck::index_t stride, + auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{}); + params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{}); + params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{}); + + auto host_tensors = PrepareGemmTensors(params, init); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& a_scales = std::get<1>(host_tensors); + const Tensor& b = std::get<2>(host_tensors); + const Tensor& b_scales = std::get<3>(host_tensors); + Tensor& c_host = std::get<4>(host_tensors); + Tensor& c_device = std::get<5>(host_tensors); + + RunHostGEMM(a, a_scales, b, b_scales, c_host); + + RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device); + + bool res = false; + if constexpr(std::is_same::value || + std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } +}; + +} // namespace mxmfma_test + namespace mfma_test { template