From 0696f999375c966c748554c5daa8efd0bfeac366 Mon Sep 17 00:00:00 2001 From: apoorva Date: Fri, 13 Jun 2025 08:30:41 +0000 Subject: [PATCH 01/62] Added tests. --- test/gemm_add/CMakeLists.txt | 5 ++++ test/gemm_add/test_gemm_add_wmma.cpp | 40 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 test/gemm_add/test_gemm_add_wmma.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index f7430b8ae1..8f099f1286 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -22,3 +22,8 @@ add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp if(result EQUAL 0) target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) endif() + +add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp new file mode 100644 index 0000000000..a2238463f5 --- /dev/null +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAdd : public TestGemmD0Common +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + + decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); +TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From 30d65b9f81e7ff4349af17577babfc747adceb37 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 18 Jun 2025 08:30:04 +0000 Subject: [PATCH 02/62] Added gemm_add examples for wmma v1 and v3 --- example/68_gemm_add/CMakeLists.txt | 6 + example/68_gemm_add/gemm_add_wmma_bf16.cpp | 341 ++++++++++++++++++ example/68_gemm_add/gemm_add_wmma_fp16.cpp | 335 +++++++++++++++++ example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp | 66 ++++ example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp | 66 ++++ 5 files changed, 814 insertions(+) create mode 100644 example/68_gemm_add/CMakeLists.txt create mode 100644 example/68_gemm_add/gemm_add_wmma_bf16.cpp create mode 100644 example/68_gemm_add/gemm_add_wmma_fp16.cpp create mode 100644 example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp create mode 100644 example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt new file mode 100644 index 0000000000..cb0b1fc457 --- /dev/null +++ b/example/68_gemm_add/CMakeLists.txt @@ -0,0 +1,6 @@ +add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) +add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) +add_example_executable(example_gemm_add_wmma_v3_fp16 gemm_add_wmma_v3_fp16.cpp) +add_example_executable(example_gemm_add_wmma_v3_bf16 gemm_add_wmma_v3_bf16.cpp) + + diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp new file mode 100644 index 0000000000..e3bfbd74cf --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 + x1; + }; +}; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" + "beta\n"); + exit(0); + } + + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp new file mode 100644 index 0000000000..4d6b94ac29 --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_fp16.cpp @@ -0,0 +1,335 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::half_t& x1) const + { + y = x0 + ck::type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const float& x1) const + { + y = ck::type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const + { + y = ck::type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const + { + y = x0 + x1; + }; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" + "beta\n"); + exit(0); + } + + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp new file mode 100644 index 0000000000..4047566e9e --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +using device_gemm_add_bf16_bf16_bf16_bf16_mk_nk_mn_mn_generic_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmDefault, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_bf16_bf16_bf16_bf16_mk_nk_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..0702346c67 --- /dev/null +++ b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +using device_gemm_add_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +using device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmDefault, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_add_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From c2077ca71d1f490772cbe8688a256c49e16d7fb7 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 18 Jun 2025 15:30:32 +0000 Subject: [PATCH 03/62] Modified the v3 code to added one fp16 bxdl instance. --- example/68_gemm_add/CMakeLists.txt | 2 + example/68_gemm_add/gemm_add_wmma_bf16.cpp | 7 - example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp | 379 ++++++++++++++--- example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp | 380 +++++++++++++++--- example/68_gemm_add/gemm_add_xdl_fp16.cpp | 326 +++++++++++++++ 5 files changed, 985 insertions(+), 109 deletions(-) create mode 100644 example/68_gemm_add/gemm_add_xdl_fp16.cpp diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index cb0b1fc457..78435c74b8 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,6 +1,8 @@ +add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_fp16 gemm_add_wmma_v3_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_bf16 gemm_add_wmma_v3_bf16.cpp) + diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp index e3bfbd74cf..cca9e492d6 100644 --- a/example/68_gemm_add/gemm_add_wmma_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -64,13 +64,6 @@ struct Add const float y_tmp = x0 + x2_tmp; y = ck::type_convert(y_tmp); } - - template <> - __host__ __device__ constexpr void - operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const - { - y = x0 + x1; - }; }; template diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp index 4047566e9e..8d2fdd216e 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp @@ -1,66 +1,343 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include +#include +#include +#include + +#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } +}; template using S = ck::Sequence; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +using BF16 = ck::bhalf_t; +using F32 = float; -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; -using device_gemm_add_bf16_bf16_bf16_bf16_mk_nk_mn_mn_generic_instances = std::tuple< - // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances = std::tuple< - // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmDefault, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1> - // clang-format on - >; +using BF16_Tuple = ck::Tuple; -void add_device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances( - std::vector>>& instances) +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using DsDataType = BF16_Tuple; +using EDataType = BF16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + BF16, + BF16, + BF16_Tuple, + BF16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +int main(int argc, char* argv[]) { - add_device_operation_instances( - instances, device_gemm_add_bf16_bf16_bf16_bf16_mk_nk_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_wmma_universal_bf16_bf16_bf16_bf16_mk_nk_mn_mn_instances{}); -} + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" + "beta\n"); + exit(0); + } + + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp index 0702346c67..9983328218 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp @@ -1,66 +1,344 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include +#include +#include +#include + +#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::half_t& x1) const + { + y = x0 + ck::type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const float& x1) const + { + y = ck::type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const + { + y = ck::type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const + { + y = x0 + x1; + }; +}; template using S = ck::Sequence; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +using F16 = ck::half_t; +using F32 = float; -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; -using device_gemm_add_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances = std::tuple< - // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< - // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmDefault, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1> - // clang-format on - >; +using F16_Tuple = ck::Tuple; -void add_device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances( - std::vector>>& instances) +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = F16_Tuple; +using EDataType = F16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + F16, + F16, + F16_Tuple, + F16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +int main(int argc, char* argv[]) { - add_device_operation_instances(instances, - device_gemm_add_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); -} + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" + "beta\n"); + exit(0); + } + + bool is_supported = ck::is_gfx11_supported(); + if(!is_supported) + { + std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() + << std::endl; + return 0; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/68_gemm_add/gemm_add_xdl_fp16.cpp b/example/68_gemm_add/gemm_add_xdl_fp16.cpp new file mode 100644 index 0000000000..77c3040171 --- /dev/null +++ b/example/68_gemm_add/gemm_add_xdl_fp16.cpp @@ -0,0 +1,326 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::half_t& x1) const + { + y = x0 + ck::type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const float& x1) const + { + y = ck::type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const + { + y = ck::type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const + { + y = x0 + x1; + }; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} From 57c3fd96c4ac52573d33d7c45cc0771f9ace754a Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 18 Jun 2025 19:23:42 +0000 Subject: [PATCH 04/62] added bf16 xdl instance. --- example/68_gemm_add/CMakeLists.txt | 12 +- example/68_gemm_add/gemm_add_xdl_bf16.cpp | 325 ++++++++++++++++++++++ 2 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 example/68_gemm_add/gemm_add_xdl_bf16.cpp diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index 78435c74b8..2cf152c893 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,8 +1,18 @@ -add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_fp16 gemm_add_wmma_v3_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_bf16 gemm_add_wmma_v3_bf16.cpp) +add_custom_target(example_gemm_add_xdl) +set_source_files_properties(example_gemm_add_xdl_fp16/gemm_add_xdl_fp16.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +add_library(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) +add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16) + +set_source_files_properties(example_gemm_add_xdl_bf16/gemm_add_xdl_bf16.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +add_library(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) +add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16) + diff --git a/example/68_gemm_add/gemm_add_xdl_bf16.cpp b/example/68_gemm_add/gemm_add_xdl_bf16.cpp new file mode 100644 index 0000000000..e4213d8d2e --- /dev/null +++ b/example/68_gemm_add/gemm_add_xdl_bf16.cpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } +}; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} From b42b6b67a59e11783b6c215a1745094b00612ca4 Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Mon, 26 May 2025 21:35:14 +0000 Subject: [PATCH 05/62] adding gemm_add wmma_cshuffle and other support (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976) Co-authored-by: Cenxuan --- .../gpu/gemm_add.hpp | 55 ++++++++++++++- ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 69 +++++++++++++++++++ ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 69 +++++++++++++++++++ 3 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 030f3c2760..12583f8d1a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -15,7 +15,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_USE_XDL void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#elif defined(CK_USE_WMMA) +void add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); +#endif // GEMM + Add + template > op_ptrs; - +#ifdef CK_USE_XDL #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -103,7 +130,31 @@ struct DeviceOperationInstanceFactory< } } #endif +#elif defined(CK_USE_WMMA) +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..c35dd22ed1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..98baf5d089 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 113ea09770594df7ea7f9f97b0b9412a61dd41a2 Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Mon, 26 May 2025 21:35:48 +0000 Subject: [PATCH 06/62] add instances into camkelists (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e) Co-authored-by: Cenxuan --- .../src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt | 2 ++ ...dd_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 2 +- ...m_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 298da1fbef..b01f3f05a8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -2,4 +2,6 @@ add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp index c35dd22ed1..7c87f54987 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -4,7 +4,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/utility/sequence.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp index 98baf5d089..70abe3faf2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -4,7 +4,7 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/utility/sequence.hpp" namespace ck { From b129e731c30fcde3e30136ba12b412ebca31b288 Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Mon, 26 May 2025 21:36:15 +0000 Subject: [PATCH 07/62] This is work in progress, edited the template parameters in order to build (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41) Co-authored-by: Cenxuan --- .../gpu/gemm_add/CMakeLists.txt | 2 +- ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 4 +- ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 30 ++++++----- profiler/src/CMakeLists.txt | 54 ++++++++++--------- 4 files changed, 50 insertions(+), 40 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index b01f3f05a8..371f47bf96 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,4 +1,4 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp index 7c87f54987..f1c571e2f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -27,7 +27,8 @@ using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_insta //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // TODO: these template variables need to be adjusted + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> // clang-format on >; @@ -38,6 +39,7 @@ using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = s //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // TODO: these template variables need to be adjusted DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp index 70abe3faf2..f29f9720f4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -23,25 +23,27 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // TODO: these template variables need to be adjusted + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> // clang-format on >; -using device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // TODO: these template variables need to be adjusted + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on >; void add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 06ce589490..2eb1038357 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,19 +40,21 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_gemm_add.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) - list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) - list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) endif() list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") @@ -92,6 +94,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp) list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) endif() endif() @@ -147,17 +150,19 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_contraction_scale_instance) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_INSTANCES device_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) - list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) - list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) endif() list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) @@ -203,6 +208,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) endif() From 455275de80ec57e7ec6380cf5e20626dcd8f53fb Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Mon, 26 May 2025 21:36:36 +0000 Subject: [PATCH 08/62] temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1) Co-authored-by: Cenxuan --- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 12 +++++++++++ .../gpu/gemm_add.hpp | 20 +++++++++++-------- .../gpu/gemm_add/CMakeLists.txt | 4 ++-- ...16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp} | 20 +++++++++---------- ..._f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} | 20 +++++++++---------- 5 files changed, 46 insertions(+), 30 deletions(-) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp} (90%) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} (71%) diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 429df2413f..f1baf223a1 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -16,11 +16,13 @@ enum struct WmmaInstr wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu16, wmma_i32_16x16x16_iu8, wmma_i32_16x16x16_iu4, // gfx12 wmma_f32_16x16x16_f16_gfx12, wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu16_gfx12, wmma_i32_16x16x16_iu8_gfx12, wmma_f32_16x16x16_f8f8_gfx12, wmma_f32_16x16x16_f8bf8_gfx12, @@ -590,6 +592,16 @@ struct WmmaSelector return WmmaInstr::wmma_bf16_16x16x16_bf16; } + template <> + constexpr auto GetWmma() + { +#ifdef __gfx12__ + return WmmaInstr::wmma_i32_16x16x16_iu16_gfx12; +#else + return WmmaInstr::wmma_i32_16x16x16_iu16; +#endif + } + template <> constexpr auto GetWmma() { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 12583f8d1a..985cf2b1ec 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -43,7 +43,7 @@ void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( Add>>>&); #elif defined(CK_USE_WMMA) -void add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); -void add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( std::vector && is_same_v && +// TODO: +// here for WMMA, currently BDataType and ADataType must be the same +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs); + add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs); } } #endif -#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) - if constexpr(is_same_v && is_same_v && +#if defined(CK_ENABLE_BF16) +// TODO: +// here for WMMA, currently BDataType and ADataType must be the same + if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs); + add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs); } } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 371f47bf96..5f8a418b4a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -2,6 +2,6 @@ add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp similarity index 90% rename from library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp index f1c571e2f0..4db8af6554 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -20,7 +20,7 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -28,11 +28,11 @@ using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_insta //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 16, 16, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> // clang-format on >; -using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -40,19 +40,19 @@ using device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = s //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + instances, device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); + instances, device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp similarity index 71% rename from library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index f29f9720f4..368d17002a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -20,7 +20,7 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -28,11 +28,11 @@ using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instance //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> // clang-format on >; -using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -40,19 +40,19 @@ using device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std: //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + instances, device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances{}); add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); + instances, device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); } } // namespace instance From 1fda4990ad37e68f74a23ff5f6a56d2fa5291df5 Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Mon, 26 May 2025 21:39:41 +0000 Subject: [PATCH 09/62] added datatype and use clang-format-12 (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec) Co-authored-by: Cenxuan --- include/ck/utility/dtype_vector.hpp | 8 ++++++++ .../ck/library/tensor_operation_instance/gpu/gemm_add.hpp | 7 ++++--- ..._c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 3 ++- ...mma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 2 +- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 65eed0624c..4a91054fd9 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2124,6 +2124,14 @@ using int32x16_t = typename vector_type::type; using int32x32_t = typename vector_type::type; using int32x64_t = typename vector_type::type; +// i16 +using int16x2_t = typename vector_type::type; +using int16x4_t = typename vector_type::type; +using int16x8_t = typename vector_type::type; +using int16x16_t = typename vector_type::type; +using int16x32_t = typename vector_type::type; +using int16x64_t = typename vector_type::type; + // i8 using int8x2_t = typename vector_type::type; using int8x4_t = typename vector_type::type; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 985cf2b1ec..4b50d6a351 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -146,15 +146,16 @@ struct DeviceOperationInstanceFactory< #endif #if defined(CK_ENABLE_BF16) -// TODO: -// here for WMMA, currently BDataType and ADataType must be the same + // TODO: + // here for WMMA, currently BDataType and ADataType must be the same if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs); + add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); } } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 4db8af6554..50c90e781d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -60,7 +60,8 @@ void add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances Add>>>& instances) { add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + instances, + device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); add_device_operation_instances( instances, device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 368d17002a..6dde4d8a87 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -43,7 +43,7 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on + // clang-format on >; void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( From 1519eaaa2a2206697ae1ea32933297d8804d4e41 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 28 May 2025 10:42:19 +0000 Subject: [PATCH 10/62] Fixing build errors --- include/ck/utility/amd_wmma.hpp | 151 ++++++++---------- include/ck/utility/dtype_vector.hpp | 2 + ...f16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 16 +- 3 files changed, 78 insertions(+), 91 deletions(-) diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index e14c0d62a8..f6f6712b9a 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -143,6 +143,33 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> } }; +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu16_w32; + +template +struct intrin_wmma_i32_16x16x16_iu16_w32<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int16x16_t& reg_a, const int16x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu16_w32( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + /********************************WAVE64 MODE***********************************************/ template @@ -263,6 +290,33 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; +// src: iu16, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu16_w64; + +template +struct intrin_wmma_i32_16x16x16_iu16_w64<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int16x16_t& reg_a, const int16x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + // gfx12 /********************************WAVE32 MODE***********************************************/ @@ -341,94 +395,25 @@ struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> } }; -// src: f8, f8, dst: fp32 -template -struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; +// src: iu16, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu16_w32_gfx12; -template <> -struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> +template +struct intrin_wmma_i32_16x16x16_iu16_w32_gfx12<16, 16, neg_a, neg_b, clamp> { template - __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + __device__ static void Run(const int16x8_t& reg_a, const int16x8_t& reg_b, FloatC& reg_c) { #if defined(__gfx12__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( - bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}]); -#else - ignore = reg_a; - ignore = reg_b; - ignore = reg_c; -#endif - } -}; - -// src: f8, bf8, dst: fp32 -template -struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; - -template <> -struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> -{ - template - __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) - { -#if defined(__gfx12__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( - bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}]); -#else - ignore = reg_a; - ignore = reg_b; - ignore = reg_c; -#endif - } -}; - -// src: bf8, f8, dst: fp32 -template -struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; - -template <> -struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> -{ - template - __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) - { -#if defined(__gfx12__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( - bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}]); -#else - ignore = reg_a; - ignore = reg_b; - ignore = reg_c; -#endif - } -}; - -// src: bf8, bf8, dst: fp32 -template -struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; - -template <> -struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> -{ - template - __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) - { -#if defined(__gfx12__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( - bit_cast(reg_a), - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}]); + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); #else ignore = reg_a; ignore = reg_b; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 4a91054fd9..3342e43a48 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2093,6 +2093,8 @@ struct vector_type()>> } }; +using int64_t = long; + // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 50c90e781d..32d5ae580f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -28,11 +28,11 @@ using device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_generic_inst //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, I32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 16, 16, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; -using device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -40,13 +40,13 @@ using device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances = //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( +void device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( std::vector Date: Wed, 11 Jun 2025 09:50:19 +0000 Subject: [PATCH 11/62] Added instances for v3 --- ...f16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 73 +++++++++++++++++++ ...3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 72 ++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..96c0150dec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; +using BF16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +using device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, BF16, PassThrough, PassThrough, AddMultiply, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, BF16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..0e9d2bdaf1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +using device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck \ No newline at end of file From 7da9f64ed0f466156d4c41abf28fff89b2f41b54 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 11 Jun 2025 11:18:51 +0000 Subject: [PATCH 12/62] Adding instances and executables --- .../src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt | 2 ++ profiler/src/CMakeLists.txt | 2 ++ 2 files changed, 4 insertions(+) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 5f8a418b4a..46a2e56726 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -4,4 +4,6 @@ add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 2eb1038357..7c381a4093 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -65,6 +65,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_OPS profile_gemm_add.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) @@ -179,6 +180,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance) list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance) From 0cce81c3e751e18870c734a0abbbd2426bfeb341 Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 12 Jun 2025 22:19:41 +0000 Subject: [PATCH 13/62] Code update of template parameters modified. --- .../gpu/gemm_add.hpp | 4 +- .../gpu/gemm_add/CMakeLists.txt | 4 +- ...6_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} | 10 +-- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 10 +-- ...f16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 66 ++++++++----------- ...3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 59 ++++++++--------- 6 files changed, 69 insertions(+), 84 deletions(-) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} (94%) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 4b50d6a351..5da5d9ff91 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -56,7 +56,7 @@ void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( PassThrough, Add>>>&); -void add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( std::vector; -using add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -46,7 +46,7 @@ using add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance // clang-format on >; -void device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( +void device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -40,9 +40,9 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 96c0150dec..8b5fd6e47e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -1,55 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using BF16 = ck::bhalf_t; -using F32 = float; -using BF16_Tuple = ck::Tuple; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Row_Tuple = ck::Tuple; - template using S = ck::Sequence; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; - static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -using device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances = - std::tuple< - // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, BF16, PassThrough, PassThrough, AddMultiply, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, BF16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; -void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances( +template +using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( std::vector>>& instances) + Add>>>& instances) { add_device_operation_instances( instances, - device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_bf16_mk_kn_mn_mn_mn_instances{}); + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); } } // namespace instance } // namespace device } // namespace tensor_operation -} // namespace ck \ No newline at end of file +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 0e9d2bdaf1..67d7db0390 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -1,54 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; -using F16_Tuple = ck::Tuple; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using Row_Tuple = ck::Tuple; - template using S = ck::Sequence; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; - static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -using device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = std::tuple< +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template +using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| - //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| - //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1> + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> // clang-format on >; -void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>& instances) + Add>>>& instances) { add_device_operation_instances( instances, - device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); } } // namespace instance } // namespace device } // namespace tensor_operation -} // namespace ck \ No newline at end of file +} // namespace ck From 6df313f70f207424ab3ae67f2771e193c3f86c62 Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 12 Jun 2025 22:20:15 +0000 Subject: [PATCH 14/62] Renamed file. --- ...mma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} (100%) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_f16_bf16_bf16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp From 06d44f1f1cf18e0218db5090d89060f4317fd116 Mon Sep 17 00:00:00 2001 From: apoorva Date: Fri, 13 Jun 2025 08:31:22 +0000 Subject: [PATCH 15/62] Added tests. --- test/gemm_add/CMakeLists.txt | 5 ++++ test/gemm_add/test_gemm_add_wmma.cpp | 40 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 test/gemm_add/test_gemm_add_wmma.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 88a0cfd0e2..662fe53a62 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -30,6 +30,11 @@ if(result EQUAL 0) target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance) endif() +add_gtest_executable(test_gemm_add_wmma test_gemm_add_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_wmma PRIVATE utility device_gemm_add_instance) +endif() + add_gtest_executable(test_gemm_add_add_fastgelu_wmma test_gemm_add_add_fastgelu_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_add_add_fastgelu_wmma PRIVATE utility device_gemm_add_add_fastgelu_instance) diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp new file mode 100644 index 0000000000..a2238463f5 --- /dev/null +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAdd : public TestGemmD0Common +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + + decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); +TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From 10d648a037b8343ddc98ea47f41cc48ee563dff9 Mon Sep 17 00:00:00 2001 From: apoorva Date: Fri, 13 Jun 2025 09:59:49 +0000 Subject: [PATCH 16/62] resolved error tests. --- test/gemm_add/test_gemm_add_wmma.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp index a2238463f5..57329cbf87 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -33,7 +33,7 @@ class TestGemmAdd : public TestGemmD0Common decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } }; -using KernelTypes = ::testing::Types, +using KernelTypes = ::testing::Types, std::tuple>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); From ef781db3050dfbadffd3044685f54720c48a848c Mon Sep 17 00:00:00 2001 From: apoorva Date: Fri, 13 Jun 2025 14:49:56 +0000 Subject: [PATCH 17/62] Fixing build errors --- .../ck/library/tensor_operation_instance/gpu/gemm_add.hpp | 2 +- ...a_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 6 +++--- test/gemm_add/test_gemm_add_wmma.cpp | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 5da5d9ff91..53ec2eacee 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -154,7 +154,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_wmma_c_shuffle_bf16_f16_bf16_bf16_mk_kn_mn_mn_instances( + add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( op_ptrs); } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 7f71dc8a9b..496ec09836 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -32,7 +32,7 @@ using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_ins // clang-format on >; -using add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -46,7 +46,7 @@ using add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instanc // clang-format on >; -void device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( std::vector decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } }; -using KernelTypes = ::testing::Types, - std::tuple>; +using KernelTypes = + ::testing::Types, Row>, + std::tuple, Row>>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From bd49ec04d42e5e12f897c9afeee3aeb1aa56f180 Mon Sep 17 00:00:00 2001 From: apoorva Date: Fri, 13 Jun 2025 15:21:08 +0000 Subject: [PATCH 18/62] Updated comments --- ...wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 3 --- ...add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 3 --- 2 files changed, 6 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 496ec09836..e4e2b84883 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -17,9 +17,6 @@ using S = ck::Sequence; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 9b5a805da4..90b347d5f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -17,9 +17,6 @@ using S = ck::Sequence; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding From 3301ef501fbac2d25918bb432b6a0cc62a14dcfb Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 09:32:02 +0000 Subject: [PATCH 19/62] removed the changes as per the MR review comment. --- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 12 -- include/ck/utility/amd_wmma.hpp | 151 ++++++++++-------- include/ck/utility/dtype_vector.hpp | 10 -- 3 files changed, 83 insertions(+), 90 deletions(-) diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index f1baf223a1..429df2413f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -16,13 +16,11 @@ enum struct WmmaInstr wmma_f32_16x16x16_bf16, wmma_f16_16x16x16_f16, wmma_bf16_16x16x16_bf16, - wmma_i32_16x16x16_iu16, wmma_i32_16x16x16_iu8, wmma_i32_16x16x16_iu4, // gfx12 wmma_f32_16x16x16_f16_gfx12, wmma_f32_16x16x16_bf16_gfx12, - wmma_i32_16x16x16_iu16_gfx12, wmma_i32_16x16x16_iu8_gfx12, wmma_f32_16x16x16_f8f8_gfx12, wmma_f32_16x16x16_f8bf8_gfx12, @@ -592,16 +590,6 @@ struct WmmaSelector return WmmaInstr::wmma_bf16_16x16x16_bf16; } - template <> - constexpr auto GetWmma() - { -#ifdef __gfx12__ - return WmmaInstr::wmma_i32_16x16x16_iu16_gfx12; -#else - return WmmaInstr::wmma_i32_16x16x16_iu16; -#endif - } - template <> constexpr auto GetWmma() { diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index f6f6712b9a..e14c0d62a8 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -143,33 +143,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> } }; -// src: iu8, dst: i32 -template -struct intrin_wmma_i32_16x16x16_iu16_w32; - -template -struct intrin_wmma_i32_16x16x16_iu16_w32<16, 16, neg_a, neg_b, clamp> -{ - template - __device__ static void Run(const int16x16_t& reg_a, const int16x16_t& reg_b, FloatC& reg_c) - { -#if defined(__gfx11__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_i32_16x16x16_iu16_w32( - neg_a, - bit_cast(reg_a), - neg_b, - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - clamp); -#else - ignore = reg_a; - ignore = reg_b; - ignore = reg_c; -#endif - } -}; - /********************************WAVE64 MODE***********************************************/ template @@ -290,33 +263,6 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> } }; -// src: iu16, dst: i32 -template -struct intrin_wmma_i32_16x16x16_iu16_w64; - -template -struct intrin_wmma_i32_16x16x16_iu16_w64<16, 16, neg_a, neg_b, clamp> -{ - template - __device__ static void Run(const int16x16_t& reg_a, const int16x16_t& reg_b, FloatC& reg_c) - { -#if defined(__gfx11__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( - neg_a, - bit_cast(reg_a), - neg_b, - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - clamp); -#else - ignore = reg_a; - ignore = reg_b; - ignore = reg_c; -#endif - } -}; - // gfx12 /********************************WAVE32 MODE***********************************************/ @@ -395,25 +341,94 @@ struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> } }; -// src: iu16, dst: i32 -template -struct intrin_wmma_i32_16x16x16_iu16_w32_gfx12; +// src: f8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; -template -struct intrin_wmma_i32_16x16x16_iu16_w32_gfx12<16, 16, neg_a, neg_b, clamp> +template <> +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> { template - __device__ static void Run(const int16x8_t& reg_a, const int16x8_t& reg_b, FloatC& reg_c) + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { #if defined(__gfx12__) - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - neg_a, - bit_cast(reg_a), - neg_b, - bit_cast(reg_b), - reg_c.template AsType()[Number<0>{}], - clamp); + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); #else ignore = reg_a; ignore = reg_b; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 3342e43a48..65eed0624c 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2093,8 +2093,6 @@ struct vector_type()>> } }; -using int64_t = long; - // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; @@ -2126,14 +2124,6 @@ using int32x16_t = typename vector_type::type; using int32x32_t = typename vector_type::type; using int32x64_t = typename vector_type::type; -// i16 -using int16x2_t = typename vector_type::type; -using int16x4_t = typename vector_type::type; -using int16x8_t = typename vector_type::type; -using int16x16_t = typename vector_type::type; -using int16x32_t = typename vector_type::type; -using int16x64_t = typename vector_type::type; - // i8 using int8x2_t = typename vector_type::type; using int8x4_t = typename vector_type::type; From 38d00277c275cd10f743b7e57d67ba37504824dd Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 10:56:18 +0000 Subject: [PATCH 20/62] Updated tests. --- .../gpu/gemm_add.hpp | 4 +- test/gemm_add/test_gemm_add_wmma.cpp | 41 ++++++++----------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 53ec2eacee..4b7bdeb5af 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -49,7 +49,7 @@ void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( Row_Tuple, Row, F16, - I8, + F16, F16_Tuple, F16, PassThrough, @@ -62,7 +62,7 @@ void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance Row_Tuple, Row, BF16, - I8, + BF16, BF16_Tuple, BF16, PassThrough, diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp index 143b5f8a48..5b6e9629a3 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -9,33 +9,26 @@ template class TestGemmAdd : public TestGemmD0Common { - private: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - using AccDataType = std::tuple_element_t<2, Tuple>; - using D0DataType = std::tuple_element_t<3, Tuple>; - using EDataType = std::tuple_element_t<4, Tuple>; - using ALayout = std::tuple_element_t<5, Tuple>; - using BLayout = std::tuple_element_t<6, Tuple>; - using D0Layout = std::tuple_element_t<7, Tuple>; - using ELayout = std::tuple_element_t<8, Tuple>; + using ProfileCall = typename TestGemmD0Common::ProfileCall; - constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; - - decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; } + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_impl::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } }; -using KernelTypes = - ::testing::Types, Row>, - std::tuple, Row>>; +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From c8b3f3d587b1464c8a776880c5f2cccf4f9c3447 Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 12:35:33 +0000 Subject: [PATCH 21/62] Restored the Cmake file that was reverted by mistake during rebase. --- profiler/src/CMakeLists.txt | 54 +++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 7c381a4093..e75a10ad19 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,21 +40,19 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_contraction_scale.cpp) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) + list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) + list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_OPS profile_gemm_add.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm.cpp) + list(APPEND PROFILER_OPS profile_gemm_streamk.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_relu.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_silu.cpp) + list(APPEND PROFILER_OPS profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm_fastgelu.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) + list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") @@ -151,19 +149,17 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_contraction_scale_instance) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_instance) + list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance) + list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance) + list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance) + list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance) + list(APPEND DEVICE_INSTANCES device_grouped_gemm_fixed_nk_instance) + list(APPEND DEVICE_INSTANCES device_grouped_gemm_fastgelu_instance) + list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) endif() list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) @@ -210,7 +206,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_instance) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) endif() From 78c2ee2ff3ed9a04fb17b885f017635097f8ce6a Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 13:23:04 +0000 Subject: [PATCH 22/62] Updated comments. --- ..._c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 7 ++++--- ...wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 5 +++-- ...shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 4 ++++ ...a_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 4 ++++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index e4e2b84883..5e48249815 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -17,6 +17,9 @@ using S = ck::Sequence; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding @@ -24,8 +27,7 @@ using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_ins //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // TODO: these template variables need to be adjusted - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -36,7 +38,6 @@ using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // TODO: these template variables need to be adjusted DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 90b347d5f0..27dff0df22 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -17,6 +17,9 @@ using S = ck::Sequence; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding @@ -24,7 +27,6 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instanc //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // TODO: these template variables need to be adjusted DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -36,7 +38,6 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // TODO: these template variables need to be adjusted DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 8b5fd6e47e..b3f862f9cd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -25,6 +25,10 @@ static constexpr auto V1 = BlockGemmPipelineVersion::v1; static constexpr auto V3 = BlockGemmPipelineVersion::v3; template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 67d7db0390..ec8fe54888 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -25,6 +25,10 @@ static constexpr auto V1 = BlockGemmPipelineVersion::v1; static constexpr auto V3 = BlockGemmPipelineVersion::v3; template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| From ed5ac2154161a5ca1237bc43d3a72167bc9a1db9 Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 13:34:38 +0000 Subject: [PATCH 23/62] Updated the template parameter description --- ...c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 8 ++++---- ...mma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 5e48249815..ed8a8d219b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -24,8 +24,8 @@ using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_ins // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on @@ -35,8 +35,8 @@ using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 27dff0df22..d1aa066792 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -24,8 +24,8 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instanc // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on @@ -35,8 +35,8 @@ using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std // clang-format off // M/N/K padding //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, From 71d65d42944309b250f6c78de3cc9478ee5b1f95 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 25 Jun 2025 08:24:32 +0000 Subject: [PATCH 24/62] Updated tests to ad BF16 instances as per review comment --- test/gemm_add/test_gemm_add_wmma.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp index 5b6e9629a3..f4d29311b8 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -25,10 +25,9 @@ class TestGemmAdd : public TestGemmD0Common } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = + ::testing::Types, Row>, + std::tuple, Row>>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From ee8c278734ccae0b4367e427c27b9f4ac77afea6 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 25 Jun 2025 09:44:18 +0000 Subject: [PATCH 25/62] Added include file and cleaned up(as per review comment) --- example/68_gemm_add/common.hpp | 105 +++++++ example/68_gemm_add/gemm_add_wmma_bf16.cpp | 276 +----------------- example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp | 261 +---------------- example/68_gemm_add/run_gem_add_example.inc | 143 +++++++++ 4 files changed, 261 insertions(+), 524 deletions(-) create mode 100644 example/68_gemm_add/common.hpp create mode 100644 example/68_gemm_add/run_gem_add_example.inc diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp new file mode 100644 index 0000000000..745a0265db --- /dev/null +++ b/example/68_gemm_add/common.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + struct ExecutionConfig final + { + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + }; + + inline bool + parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) + { + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + std::cerr + << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " + "StrideE" + << std::endl; + return false; + } + + return true; + } +} \ No newline at end of file diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp index cca9e492d6..80bd15bb98 100644 --- a/example/68_gemm_add/gemm_add_wmma_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -1,81 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/host_utility/device_prop.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x1); - y = x0 + x1_tmp; - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x0); - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x1_tmp + x2_tmp; - y = ck::type_convert(y_tmp); - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x0 + x2_tmp; - y = ck::type_convert(y_tmp); - } -}; - -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common.hpp" using ADataType = BF16; using BDataType = BF16; @@ -139,196 +65,16 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ S<1, 32, 1, 4>, 8>; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; +// clang-format on - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; +#include "run_gem_add_example.inc" - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" - "beta\n"); - exit(0); - } - - bool is_supported = ck::is_gfx11_supported(); - if(!is_supported) - { - std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() - << std::endl; - return 0; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gem_add_example(argc, argv); } \ No newline at end of file diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp index 8d2fdd216e..083a555488 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp @@ -1,70 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/host_utility/device_prop.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x1); - y = x0 + x1_tmp; - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x0); - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x1_tmp + x2_tmp; - y = ck::type_convert(y_tmp); - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x0 + x2_tmp; - y = ck::type_convert(y_tmp); - } -}; +#include "common.hpp" template using S = ck::Sequence; @@ -146,198 +83,4 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" - "beta\n"); - exit(0); - } - - bool is_supported = ck::is_gfx11_supported(); - if(!is_supported) - { - std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() - << std::endl; - return 0; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - 1, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +} \ No newline at end of file diff --git a/example/68_gemm_add/run_gem_add_example.inc b/example/68_gemm_add/run_gem_add_example.inc new file mode 100644 index 0000000000..5ee3ca3318 --- /dev/null +++ b/example/68_gemm_add/run_gem_add_example.inc @@ -0,0 +1,143 @@ +#pragma once + +bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_multiply(problem_size, config); +} From 7840db41d07d64c78b40711b0e2b7c142d3753d8 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 25 Jun 2025 13:17:32 +0000 Subject: [PATCH 26/62] Updated and optimized the example code for all types. --- example/68_gemm_add/common.hpp | 119 ++++---- example/68_gemm_add/gemm_add_wmma_bf16.cpp | 10 +- example/68_gemm_add/gemm_add_wmma_fp16.cpp | 260 +---------------- example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp | 7 +- example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp | 261 +----------------- example/68_gemm_add/gemm_add_xdl_bf16.cpp | 249 +---------------- example/68_gemm_add/run_gem_add_example.inc | 12 +- .../68_gemm_add/run_gemm_add_example_v3.inc | 144 ++++++++++ 8 files changed, 233 insertions(+), 829 deletions(-) create mode 100644 example/68_gemm_add/run_gemm_add_example_v3.inc diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp index 745a0265db..4435503e6b 100644 --- a/example/68_gemm_add/common.hpp +++ b/example/68_gemm_add/common.hpp @@ -12,7 +12,19 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#ifndef CK_USE_XDL #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#endif + +#ifndef CK_USE_MULTIPLE_D_WMMA +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#endif + +#ifndef CK_USE_WMMA_V3 +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#endif + #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/data_type.hpp" @@ -46,60 +58,59 @@ struct ProblemSize final ck::index_t StrideB = 4096; ck::index_t StrideD = 4096; ck::index_t StrideE = 4096; +}; +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; - struct ExecutionConfig final +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) { - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - }; - - inline bool - parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) - { - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - std::cerr - << "arg1: verification (0=no, 1=yes)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " - "StrideE" - << std::endl; - return false; - } - - return true; + // use default case } -} \ No newline at end of file + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD = std::stoi(argv[9]); + problem_size.StrideE = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD," + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp index 80bd15bb98..bf9fc119f7 100644 --- a/example/68_gemm_add/gemm_add_wmma_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -67,14 +67,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - #include "run_gem_add_example.inc" -int main(int argc, char* argv[]) { return !run_gem_add_example(argc, argv); } \ No newline at end of file +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp index 4d6b94ac29..3a6d40ea49 100644 --- a/example/68_gemm_add/gemm_add_wmma_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_fp16.cpp @@ -1,71 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/host_utility/device_prop.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::half_t& x1) const - { - y = x0 + ck::type_convert(x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const float& x1) const - { - y = ck::type_convert(x0 + x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const - { - y = ck::type_convert(x0) + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const - { - y = x0 + x1; - }; -}; +#include "common.hpp" template using S = ck::Sequence; @@ -140,196 +76,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ S<1, 32, 1, 4>, 8>; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; +// clang-format on - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; +#include "run_gem_add_example.inc" - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" - "beta\n"); - exit(0); - } - - bool is_supported = ck::is_gfx11_supported(); - if(!is_supported) - { - std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() - << std::endl; - return 0; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp index 083a555488..7a4204e12d 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp @@ -83,4 +83,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; -} \ No newline at end of file + +// clang-format on + +#include "run_gemm_add_example_v3.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp index 9983328218..c44d124343 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp @@ -1,71 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/host_utility/device_prop.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::half_t& x1) const - { - y = x0 + ck::type_convert(x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const float& x1) const - { - y = ck::type_convert(x0 + x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const - { - y = ck::type_convert(x0) + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const - { - y = x0 + x1; - }; -}; +#include "common.hpp" template using S = ck::Sequence; @@ -148,197 +84,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = true; +// clang-format on - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; +#include "run_gemm_add_example_v3.inc" - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE" - "beta\n"); - exit(0); - } - - bool is_supported = ck::is_gfx11_supported(); - if(!is_supported) - { - std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name() - << std::endl; - return 0; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - 1, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } \ No newline at end of file diff --git a/example/68_gemm_add/gemm_add_xdl_bf16.cpp b/example/68_gemm_add/gemm_add_xdl_bf16.cpp index e4213d8d2e..f5bfc14ebc 100644 --- a/example/68_gemm_add/gemm_add_xdl_bf16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_bf16.cpp @@ -1,69 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x1); - y = x0 + x1_tmp; - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const ck::bhalf_t& x0, const ck::bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x0); - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x1_tmp + x2_tmp; - y = ck::type_convert(y_tmp); - } - - template <> - __host__ __device__ constexpr void - operator()(ck::bhalf_t& y, const float& x0, const ck::bhalf_t& x1) const - { - const float x2_tmp = ck::type_convert(x1); - const float y_tmp = x0 + x2_tmp; - y = ck::type_convert(y_tmp); - } -}; +#include "common.hpp" template using S = ck::Sequence; @@ -139,187 +77,6 @@ using DeviceOpInstance = S<1, 32, 1, 8>, 8>; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; +#include "run_gem_add_example.inc" - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/run_gem_add_example.inc b/example/68_gemm_add/run_gem_add_example.inc index 5ee3ca3318..7a718ae93e 100644 --- a/example/68_gemm_add/run_gem_add_example.inc +++ b/example/68_gemm_add/run_gem_add_example.inc @@ -4,7 +4,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config { using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -29,7 +29,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - switch(init_method) + switch(config.init_method) { case 0: break; case 1: @@ -60,6 +60,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config // do GEMM auto device_op = DeviceOpInstance{}; auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), @@ -83,7 +84,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -98,7 +99,7 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) + if(config.do_verification) { Tensor c_m_n({M, N}); @@ -138,6 +139,5 @@ bool run_gemm_add_example(int argc, char* argv[]) ProblemSize problem_size; ExecutionConfig config; - return !parse_cmd_args(argc, argv, problem_size, config) || - run_gemm_add_multiply(problem_size, config); + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_add(problem_size, config); } diff --git a/example/68_gemm_add/run_gemm_add_example_v3.inc b/example/68_gemm_add/run_gemm_add_example_v3.inc new file mode 100644 index 0000000000..9e836276b0 --- /dev/null +++ b/example/68_gemm_add/run_gemm_add_example_v3.inc @@ -0,0 +1,144 @@ +#pragma once + +bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_add(problem_size, config); +} From 30378587a6261b75043717fb745b8e321945d0df Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 25 Jun 2025 13:24:52 +0000 Subject: [PATCH 27/62] Fixed clang format --- example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp | 2 +- example/68_gemm_add/gemm_add_xdl_fp16.cpp | 250 +----------------- 2 files changed, 4 insertions(+), 248 deletions(-) diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp index c44d124343..7844ae721d 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp @@ -88,4 +88,4 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ #include "run_gemm_add_example_v3.inc" -int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } \ No newline at end of file +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_xdl_fp16.cpp b/example/68_gemm_add/gemm_add_xdl_fp16.cpp index 77c3040171..fd86738260 100644 --- a/example/68_gemm_add/gemm_add_xdl_fp16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_fp16.cpp @@ -1,70 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -struct Add -{ - template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(double& y, const double& x0, const double& x1) const - { - y = x0 + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const ck::half_t& x1) const - { - y = x0 + ck::type_convert(x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const float& x1) const - { - y = ck::type_convert(x0 + x1); - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const float& x0, const ck::half_t& x1) const - { - y = ck::type_convert(x0) + x1; - }; - - template <> - __host__ __device__ constexpr void - operator()(ck::half_t& y, const ck::half_t& x0, const ck::half_t& x1) const - { - y = x0 + x1; - }; -}; +#include "common.hpp" template using S = ck::Sequence; @@ -140,187 +77,6 @@ using DeviceOpInstance = S<1, 32, 1, 8>, 8>; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; +#include "run_gem_add_example.inc" - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } From 35aab35d960d86b8eed3276c5706a48bcf7fccab Mon Sep 17 00:00:00 2001 From: apoorva Date: Thu, 19 Jun 2025 14:06:32 +0000 Subject: [PATCH 28/62] Added bf16 wmma instance for add_relu --- .../gpu/gemm_add_relu.hpp | 62 +++++++++++++++- .../gpu/gemm_add_relu/CMakeLists.txt | 1 + ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 71 +++++++++++++++++++ 3 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index 293e14b811..b79059de9a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_XDL void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); +#elif defined(CK_USE_WMMA) +void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); +#endif + // GEMM + Add + Relu template > op_ptrs; +#ifdef CK_USE_XDL #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -106,6 +136,36 @@ struct DeviceOperationInstanceFactory< } #endif +#elif defined(CK_USE_WMMA) + // For wmma ADataType must be same as BDatatype. + (CK_ENABLE_FP16) if constexpr(is_same_v && + is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +// For wmma ADataType must be same as BDatatype. +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 043bdab001..5797228e87 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -2,4 +2,5 @@ add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..48bd79b98f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 6f891831798e135f402b9e7e70bf6a688c30b930 Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 23 Jun 2025 22:16:47 +0000 Subject: [PATCH 29/62] Added f16 wmma instance and corrected bf16 instance errors. --- .../gpu/gemm_add_relu/CMakeLists.txt | 2 + ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 32 ++++----- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 71 +++++++++++++++++++ 3 files changed, 89 insertions(+), 16 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 5797228e87..30cbadf3d8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -3,4 +3,6 @@ add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp ) + diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index 48bd79b98f..cb32f2d118 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -20,38 +20,38 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = +using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< // clang-format off // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer|MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> // clang-format on >; -using device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( std::vector +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer|MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From cdaff7f210f0b9849fb397eda84aedfa3b7f399b Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 24 Jun 2025 20:17:46 +0000 Subject: [PATCH 30/62] Added instances to Cmake --- .../gpu/gemm_add_relu/CMakeLists.txt | 5 ++++- profiler/src/CMakeLists.txt | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 30cbadf3d8..1a4ed3a279 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,4 +1,4 @@ -# ONLY XDL_KERNELS +# XDL_AND_WMMA KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -6,3 +6,6 @@ add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp ) + + +add_executable(device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 35a4e184a0..abb037484c 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -190,8 +190,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR - (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")) + (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" )) list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") @@ -205,6 +206,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_add_relu_instance) endif() endif() From 6a116fa958b93bb0efce452252015b8f0616404c Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 10:21:25 +0000 Subject: [PATCH 31/62] Modified the template parameters to make the instances work. --- ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 8 +++---- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 24 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp index cb32f2d118..58a42ae223 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -28,7 +28,7 @@ using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generi //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -39,9 +39,9 @@ using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instan //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 0dd19c6916..0e0fa7497f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -28,41 +28,41 @@ using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_in //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> // clang-format on >; -using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< +using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off // M/N/K padding //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( std::vector>>& instances) { add_device_operation_instances( instances, - device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances{}); add_device_operation_instances( - instances, device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); + instances, device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); } } // namespace instance From bb7f6650f70562202762a77d43945f2bd0d7082d Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 12:02:28 +0000 Subject: [PATCH 32/62] Fixed typo in profiler --- profiler/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index abb037484c..a6a457ef42 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -206,7 +206,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_add_relu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) endif() endif() From f5843dd22be5dcb0fd4e41ec17c23d43b8dfed04 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 12:37:46 +0000 Subject: [PATCH 33/62] Added v3 instances for gemm_add_relu --- .../gpu/gemm_add_relu/CMakeLists.txt | 4 +- ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 71 +++++++++++++++++++ ...3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 70 ++++++++++++++++++ 3 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 1a4ed3a279..8fb7f7fb72 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -4,8 +4,8 @@ add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp ) - -add_executable(device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..35c373a0e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances< + GemmDefault>{}); + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..794b7f0e3e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +static constexpr auto V1 = BlockGemmPipelineVersion::v1; +static constexpr auto V3 = BlockGemmPipelineVersion::v3; + +template + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| + //##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| + //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3> + // clang-format on + >; + +void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances< + GemmMNKPadding>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 6ec0ad2758eba9beb54b995882488b0c7eb3354d Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 13:44:18 +0000 Subject: [PATCH 34/62] Added test for gemm_add_relu wmma instance --- .../gpu/gemm_add_relu.hpp | 10 +++--- test/gemm_add/CMakeLists.txt | 5 +++ test/gemm_add/test_gemm_add_relu_wmma.cpp | 34 +++++++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 test/gemm_add/test_gemm_add_relu_wmma.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index b79059de9a..2cc7cab5e6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -138,14 +138,14 @@ struct DeviceOperationInstanceFactory< #elif defined(CK_USE_WMMA) // For wmma ADataType must be same as BDatatype. - (CK_ENABLE_FP16) if constexpr(is_same_v && - is_same_v && - is_same_v && is_same_v) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_relu_wmma_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( op_ptrs); } } @@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { - add_device_gemm_add_relu_wmma_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( op_ptrs); } } diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 18fc3ee8f8..2e50516082 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -39,3 +39,8 @@ add_gtest_executable(test_gemm_bilinear_wmma test_gemm_bilinear_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_bilinear_wmma PRIVATE utility device_gemm_bilinear_instance) endif() + +add_gtest_executable(test_gemm_add_relu_wmma test_gemm_add_relu_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu_wmma PRIVATE utility device_gemm_add_relu_instance) +endif() \ No newline at end of file diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp new file mode 100644 index 0000000000..e1e304f70f --- /dev/null +++ b/test/gemm_add/test_gemm_add_relu_wmma.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_relu_impl.hpp" +#include "test_gemm_common.hpp" + +template +class TestGemmAddRelu : public TestGemmD0Common +{ + using ProfileCall = typename TestGemmD0Common::ProfileCall; + + ProfileCall GetImpl() override + { + return ck::profiler::profile_gemm_add_relu_impl< + typename TestGemmD0Common::ADataType, + typename TestGemmD0Common::BDataType, + typename TestGemmD0Common::AccDataType, + typename TestGemmD0Common::D0DataType, + typename TestGemmD0Common::EDataType, + typename TestGemmD0Common::ALayout, + typename TestGemmD0Common::BLayout, + typename TestGemmD0Common::D0Layout, + typename TestGemmD0Common::ELayout>; + } +}; + +using KernelTypes = + ::testing::Types, Row>, + std::tuple, Row>>; + +TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); +TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); } From feca919c25b10ed020fb0f375f60232cb6c57b11 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 1 Jul 2025 15:06:32 +0000 Subject: [PATCH 35/62] Cleaned up the code. --- example/68_gemm_add/CMakeLists.txt | 21 ++++++++++--------- example/68_gemm_add/common.hpp | 10 ++++----- example/68_gemm_add/gemm_add_wmma_fp16.cpp | 11 ---------- example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp | 13 ------------ example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp | 15 ------------- 5 files changed, 15 insertions(+), 55 deletions(-) diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index 2cf152c893..5bd7d73a92 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,18 +1,19 @@ +add_custom_target(example_gemm_add_xdl) + +add_library(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) +add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) + +add_library(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) +add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) + + +add_custom_target(example_gemm_add_wmma) add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_fp16 gemm_add_wmma_v3_fp16.cpp) add_example_executable(example_gemm_add_wmma_v3_bf16 gemm_add_wmma_v3_bf16.cpp) -add_custom_target(example_gemm_add_xdl) -set_source_files_properties(example_gemm_add_xdl_fp16/gemm_add_xdl_fp16.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -add_library(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) -add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) -add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16) - -set_source_files_properties(example_gemm_add_xdl_bf16/gemm_add_xdl_bf16.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -add_library(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) -add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) -add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16) + diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp index 4435503e6b..eab37e4132 100644 --- a/example/68_gemm_add/common.hpp +++ b/example/68_gemm_add/common.hpp @@ -13,17 +13,11 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#ifndef CK_USE_XDL #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" -#endif -#ifndef CK_USE_MULTIPLE_D_WMMA #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#endif -#ifndef CK_USE_WMMA_V3 #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" -#endif #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/data_type.hpp" @@ -48,6 +42,10 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using Row_Tuple = ck::Tuple; +using F16_Tuple = ck::Tuple; +using BF16_Tuple = ck::Tuple; + struct ProblemSize final { ck::index_t M = 3840; diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp index 3a6d40ea49..3aa25bb471 100644 --- a/example/68_gemm_add/gemm_add_wmma_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_fp16.cpp @@ -3,17 +3,6 @@ #include "common.hpp" -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using ADataType = F16; using BDataType = F16; using AccDataType = F32; diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp index 7a4204e12d..2a3641defc 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp @@ -3,19 +3,6 @@ #include "common.hpp" -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using BF16_Tuple = ck::Tuple; - using ADataType = BF16; using BDataType = BF16; using AccDataType = F32; diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp index 7844ae721d..c98fc4b39e 100644 --- a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp @@ -3,19 +3,6 @@ #include "common.hpp" -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using F16_Tuple = ck::Tuple; - using ADataType = F16; using BDataType = F16; using AccDataType = F32; @@ -24,8 +11,6 @@ using DDataType = F16; using DsDataType = F16_Tuple; using EDataType = F16; -using Row_Tuple = ck::Tuple; - using ALayout = Row; using BLayout = Row; using DLayout = Row; From ba9c637c0beebcc34b8dc7eeee0626512d7f6b41 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 2 Jul 2025 14:12:52 +0000 Subject: [PATCH 36/62] Added examples for gemm_add_relu --- example/69_gemm_add_relu/CMakeLists.txt | 21 +++ example/69_gemm_add_relu/common.hpp | 114 ++++++++++++++ .../gemm_add_relu_wmma_bf16.cpp | 72 +++++++++ .../gemm_add_relu_wmma_fp16.cpp | 72 +++++++++ .../gemm_add_relu_wmma_v3_bf16.cpp | 78 ++++++++++ .../gemm_add_relu_wmma_v3_fp16.cpp | 76 +++++++++ .../gemm_add_relu_xdl_bf16.cpp | 82 ++++++++++ .../gemm_add_relu_xdl_fp16.cpp | 82 ++++++++++ .../run_gem_add_relu_example.inc | 144 +++++++++++++++++ .../run_gemm_add_relu_example_v3.inc | 145 ++++++++++++++++++ 10 files changed, 886 insertions(+) create mode 100644 example/69_gemm_add_relu/CMakeLists.txt create mode 100644 example/69_gemm_add_relu/common.hpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp create mode 100644 example/69_gemm_add_relu/run_gem_add_relu_example.inc create mode 100644 example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt new file mode 100644 index 0000000000..936e9acea3 --- /dev/null +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -0,0 +1,21 @@ +add_custom_target(example_gemm_add_relu_xdl) + +add_library(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) +add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) + +add_library(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) +add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) + + +add_custom_target(example_gemm_add_relu_wmma) +add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp) + +add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp) + +add_example_executable(example_gemm_add_relu_wmma_v3_fp16 gemm_add_relu_wmma_v3_fp16.cpp) +add_example_executable(example_gemm_add_relu_wmma_v3_bf16 gemm_add_relu_wmma_v3_bf16.cpp) + + + + + diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp new file mode 100644 index 0000000000..151653e515 --- /dev/null +++ b/example/69_gemm_add_relu/common.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row_Tuple = ck::Tuple; +using F16_Tuple = ck::Tuple; +using BF16_Tuple = ck::Tuple; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; +}; +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD = std::stoi(argv[9]); + problem_size.StrideE = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD," + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp new file mode 100644 index 0000000000..34b791573a --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +// clang-format on + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp new file mode 100644 index 0000000000..8459e67cb4 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +// clang-format on + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp new file mode 100644 index 0000000000..84a40ab3f5 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using DsDataType = BF16_Tuple; +using EDataType = BF16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + BF16, + BF16, + BF16_Tuple, + BF16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_v3.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..cca26d6212 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = F16_Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + F16, + F16, + F16_Tuple, + F16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_v3.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp new file mode 100644 index 0000000000..824b1c2f10 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp new file mode 100644 index 0000000000..ef8c4cdcf8 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/69_gemm_add_relu/run_gem_add_relu_example.inc b/example/69_gemm_add_relu/run_gem_add_relu_example.inc new file mode 100644 index 0000000000..9d17f7863a --- /dev/null +++ b/example/69_gemm_add_relu/run_gem_add_relu_example.inc @@ -0,0 +1,144 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_relu(problem_size, config); +} diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc new file mode 100644 index 0000000000..3c787421eb --- /dev/null +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc @@ -0,0 +1,145 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_relu(problem_size, config); +} From 5c491e7a4bf2e2ad1b7f1d419b665cdb367cb9a0 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 2 Jul 2025 18:04:10 +0000 Subject: [PATCH 37/62] Fixing typo to resolve build errors. --- example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp | 2 +- example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp index 84a40ab3f5..85b2a65b20 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp @@ -38,7 +38,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ F32, PassThrough, PassThrough, - Add, + AddRelu, GemmSpec, 128, 128, diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp index cca26d6212..1a172b20a6 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp @@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ F32, PassThrough, PassThrough, - Add, + AddRelu, GemmSpec, 128, 128, From 8a5bb256f7196a7328821bc80719a91e2837a689 Mon Sep 17 00:00:00 2001 From: apoorva Date: Mon, 7 Jul 2025 09:41:00 +0000 Subject: [PATCH 38/62] Fixes applied to fix the precision loss. --- example/68_gemm_add/CMakeLists.txt | 4 ---- example/68_gemm_add/run_gem_add_example.inc | 6 +++--- example/68_gemm_add/run_gemm_add_example_v3.inc | 6 +++--- .../gpu/element/binary_element_wise_operation.hpp | 2 +- .../{test_gemm_add_xdl.hpp => test_gemm_add_xdl.cpp} | 0 5 files changed, 7 insertions(+), 11 deletions(-) rename test/gemm_add/{test_gemm_add_xdl.hpp => test_gemm_add_xdl.cpp} (100%) diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index 5bd7d73a92..f64a291b97 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,12 +1,8 @@ add_custom_target(example_gemm_add_xdl) -add_library(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) - -add_library(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) - add_custom_target(example_gemm_add_wmma) add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) diff --git a/example/68_gemm_add/run_gem_add_example.inc b/example/68_gemm_add/run_gem_add_example.inc index 7a718ae93e..3a713a0c3d 100644 --- a/example/68_gemm_add/run_gem_add_example.inc +++ b/example/68_gemm_add/run_gem_add_example.inc @@ -128,10 +128,10 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); } - return 0; + return true; } bool run_gemm_add_example(int argc, char* argv[]) @@ -139,5 +139,5 @@ bool run_gemm_add_example(int argc, char* argv[]) ProblemSize problem_size; ExecutionConfig config; - return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_add(problem_size, config); + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config); } diff --git a/example/68_gemm_add/run_gemm_add_example_v3.inc b/example/68_gemm_add/run_gemm_add_example_v3.inc index 9e836276b0..b99b889416 100644 --- a/example/68_gemm_add/run_gemm_add_example_v3.inc +++ b/example/68_gemm_add/run_gemm_add_example_v3.inc @@ -129,10 +129,10 @@ bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); } - return 0; + return true; } bool run_gemm_add_example(int argc, char* argv[]) @@ -140,5 +140,5 @@ bool run_gemm_add_example(int argc, char* argv[]) ProblemSize problem_size; ExecutionConfig config; - return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm_add(problem_size, config); + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config); } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 34c76b89e4..35eb7841cc 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -47,7 +47,7 @@ struct Add __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const { - y = type_convert(x0) + x1; + y = x0 + type_convert(x1); }; template <> diff --git a/test/gemm_add/test_gemm_add_xdl.hpp b/test/gemm_add/test_gemm_add_xdl.cpp similarity index 100% rename from test/gemm_add/test_gemm_add_xdl.hpp rename to test/gemm_add/test_gemm_add_xdl.cpp From 86ca6b827d32163f9444dd36f648438df60c07cc Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 11:23:33 +0000 Subject: [PATCH 39/62] Removed the old wmma instances. --- .../gemm_add_relu_wmma_bf16.cpp | 72 --------- .../gemm_add_relu_wmma_fp16.cpp | 72 --------- .../run_gem_add_relu_example.inc | 144 ------------------ ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 71 --------- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 71 --------- 5 files changed, 430 deletions(-) delete mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp delete mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp delete mode 100644 example/69_gemm_add_relu/run_gem_add_relu_example.inc delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp deleted file mode 100644 index 34b791573a..0000000000 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = BF16; -using EDataType = BF16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = AddRelu; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< - ALayout, - BLayout, - ck::Tuple, - ELayout, - ADataType, - BDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 2, // Prefetch stage - 128, // BlockSize - 128, // MPerBlock - 64, // NPerBlock - 64, // KPerBlock - 8, // K1 - 16, // MPerWmma - 16, // NPerWmma - 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, // C shuffle (M Repeat) Per store - 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, - 8>; - -// clang-format on - -#include "run_gem_add_relu_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp deleted file mode 100644 index 8459e67cb4..0000000000 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = AddRelu; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< - ALayout, - BLayout, - ck::Tuple, - ELayout, - ADataType, - BDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 2, // Prefetch stage - 128, // BlockSize - 128, // MPerBlock - 64, // NPerBlock - 64, // KPerBlock - 8, // K1 - 16, // MPerWmma - 16, // NPerWmma - 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, // C shuffle (M Repeat) Per store - 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, - 8>; - -// clang-format on - -#include "run_gem_add_relu_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/run_gem_add_relu_example.inc b/example/69_gemm_add_relu/run_gem_add_relu_example.inc deleted file mode 100644 index 9d17f7863a..0000000000 --- a/example/69_gemm_add_relu/run_gem_add_relu_example.inc +++ /dev/null @@ -1,144 +0,0 @@ -#pragma once - -bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - using namespace ck::literals; - - auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(config.init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(config.do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} - -bool run_gemm_add_relu_example(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - return !parse_cmd_args(argc, argv, problem_size, config) || - run_gemm_add_relu(problem_size, config); -} diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp deleted file mode 100644 index 58a42ae223..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/utility/sequence.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances = - std::tuple< - // clang-format off - // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer|MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< - // clang-format off - // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; - -void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp deleted file mode 100644 index 0e0fa7497f..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/utility/sequence.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = - std::tuple< - // clang-format off - // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer|MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< - // clang-format off - // M/N/K padding - //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma|Wmma| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 512, 64, 512, 32, 8, 16, 16, 4, 2, S<4, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; - -void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 9b64da2298d9347da1a5ae81dfa0cb56752bd70b Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 11:26:01 +0000 Subject: [PATCH 40/62] Added wrapper and renamed the wmma_v3 instances --- example/69_gemm_add_relu/CMakeLists.txt | 4 -- ...3_bf16.cpp => gemm_add_relu_wmma_bf16.cpp} | 2 +- ...3_fp16.cpp => gemm_add_relu_wmma_fp16.cpp} | 2 +- ...e_v3.inc => run_gemm_add_relu_example.inc} | 0 .../gpu/gemm_add_relu.hpp | 47 +++++++++---------- .../gpu/gemm_add_relu/CMakeLists.txt | 3 -- ...6_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} | 0 ..._f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} | 0 8 files changed, 25 insertions(+), 33 deletions(-) rename example/69_gemm_add_relu/{gemm_add_relu_wmma_v3_bf16.cpp => gemm_add_relu_wmma_bf16.cpp} (97%) rename example/69_gemm_add_relu/{gemm_add_relu_wmma_v3_fp16.cpp => gemm_add_relu_wmma_fp16.cpp} (96%) rename example/69_gemm_add_relu/{run_gemm_add_relu_example_v3.inc => run_gemm_add_relu_example.inc} (100%) rename library/src/tensor_operation_instance/gpu/gemm_add_relu/{device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp => device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_add_relu/{device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp => device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} (100%) diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt index 936e9acea3..fe9d783755 100644 --- a/example/69_gemm_add_relu/CMakeLists.txt +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -1,9 +1,7 @@ add_custom_target(example_gemm_add_relu_xdl) -add_library(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) -add_library(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) @@ -12,8 +10,6 @@ add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.c add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp) -add_example_executable(example_gemm_add_relu_wmma_v3_fp16 gemm_add_relu_wmma_v3_fp16.cpp) -add_example_executable(example_gemm_add_relu_wmma_v3_bf16 gemm_add_relu_wmma_v3_bf16.cpp) diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp similarity index 97% rename from example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp rename to example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp index 85b2a65b20..c91f2220aa 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp @@ -73,6 +73,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ // clang-format on -#include "run_gemm_add_relu_example_v3.inc" +#include "run_gemm_add_relu_example.inc" int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp similarity index 96% rename from example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp rename to example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp index 1a172b20a6..1d7febe66e 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp @@ -71,6 +71,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ // clang-format on -#include "run_gemm_add_relu_example_v3.inc" +#include "run_gemm_add_relu_example.inc" int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example.inc similarity index 100% rename from example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc rename to example/69_gemm_add_relu/run_gemm_add_relu_example.inc diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index 2cc7cab5e6..0792e3eb89 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -45,30 +45,30 @@ void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan #elif defined(CK_USE_WMMA) void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( - std::vector>>&); + std::vector>>&); void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( - std::vector>>&); + std::vector>>&); #endif // GEMM + Add + Relu @@ -137,7 +137,7 @@ struct DeviceOperationInstanceFactory< #endif #elif defined(CK_USE_WMMA) - // For wmma ADataType must be same as BDatatype. + #if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -151,7 +151,6 @@ struct DeviceOperationInstanceFactory< } #endif -// For wmma ADataType must be same as BDatatype. #if defined(CK_ENABLE_BF16) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 8fb7f7fb72..28e0ccb33d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -4,8 +4,5 @@ add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp ) - diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp From 669befb25ab2bbcbdbb9e4140e1b4736302d7869 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 12:06:38 +0000 Subject: [PATCH 41/62] Updated copyrights and added wrappers. --- .../gemm_add_relu_wmma_bf16.cpp | 2 +- .../gemm_add_relu_wmma_fp16.cpp | 2 +- .../gpu/gemm_add_relu.hpp | 44 +++++++++---------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp index c91f2220aa..abb33ad6d3 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp index 1d7febe66e..b71a5affdb 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index 0792e3eb89..c039f94021 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -81,29 +81,29 @@ template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - AddRelu>> + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>> { - using DeviceOp = DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - AddRelu>; + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; static auto GetInstances() { From bdfdb0c11ec93c819d734afcc52945d8d43ffcd6 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 12:17:20 +0000 Subject: [PATCH 42/62] Fixes applied according to review comments --- example/69_gemm_add_relu/CMakeLists.txt | 12 +++++------- test/gemm_add/test_gemm_add_relu_wmma.cpp | 5 ++--- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt index fe9d783755..9ab3ef5a45 100644 --- a/example/69_gemm_add_relu/CMakeLists.txt +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -1,17 +1,15 @@ add_custom_target(example_gemm_add_relu_xdl) add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_fp16) add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) - +add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_bf16) add_custom_target(example_gemm_add_relu_wmma) + add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp) +add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_bf16) add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp) - - - - - - +add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_fp16) diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp index e1e304f70f..8ff2f4217b 100644 --- a/test/gemm_add/test_gemm_add_relu_wmma.cpp +++ b/test/gemm_add/test_gemm_add_relu_wmma.cpp @@ -26,9 +26,8 @@ class TestGemmAddRelu : public TestGemmD0Common } }; -using KernelTypes = - ::testing::Types, Row>, - std::tuple, Row>>; +using KernelTypes = ::testing::Types, + std::tuple>; TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); TYPED_TEST(TestGemmAddRelu, Test_BF16FP16) { this->Run(); } From d3a26e5ceedc753f8e5f33d33f9f55c385d81da0 Mon Sep 17 00:00:00 2001 From: Apoorva Kalyani Date: Tue, 8 Jul 2025 12:20:24 +0000 Subject: [PATCH 43/62] Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Robin Voetter --- profiler/src/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index a6a457ef42..9c37316292 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -192,7 +192,6 @@ endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" )) list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") @@ -203,10 +202,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance) - list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) endif() endif() From 84b0b324cf6de3bb29a8aac17e28b58191b582fa Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 13:51:48 +0000 Subject: [PATCH 44/62] Removed the old wmma instances. --- .../gpu/gemm_add/CMakeLists.txt | 3 +- ...16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp | 70 ------------------- ...e_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp | 69 ------------------ 3 files changed, 1 insertion(+), 141 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 2715077268..3e71429bdd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -2,8 +2,7 @@ add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp deleted file mode 100644 index ed8a8d219b..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/utility/sequence.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< - // clang-format off - // M/N/K padding - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> - // clang-format on - >; - -using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< - // clang-format off - // M/N/K padding - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, BF16, BF16, F32, F32, BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; - -void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp deleted file mode 100644 index d1aa066792..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" -#include "ck/utility/sequence.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// e = elementwise((a * b), d0, d1) -// outout: e[m, n] -// input: a[m, k], b[k, n], d0[m, n], d1[m, n] -using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< - // clang-format off - // M/N/K padding - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8> - // clang-format on - >; - -using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< - // clang-format off - // M/N/K padding - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MwmmaPerWave| NwmmaPerWave| _MBlock_MWaveMPerwmma| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerwmma| _NWaveNPerwmma| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32,F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; - -void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instances{}); - add_device_operation_instances( - instances, device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 516d1f519461ba59d92dd21ca438157370d12140 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 13:55:24 +0000 Subject: [PATCH 45/62] Updated wrapper for the v3 instances --- .../tensor_operation_instance/gpu/gemm_add/CMakeLists.txt | 4 ++-- ...a_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} | 0 ..._wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} | 0 test/gemm_add/test_gemm_add_wmma.cpp | 5 ++--- 4 files changed, 4 insertions(+), 5 deletions(-) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_add/{device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp => device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp} (100%) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 3e71429bdd..478e9a8ab8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -3,6 +3,6 @@ add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp - device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_v3_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp index f4d29311b8..5fc6738ca9 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -25,9 +25,8 @@ class TestGemmAdd : public TestGemmD0Common } }; -using KernelTypes = - ::testing::Types, Row>, - std::tuple, Row>>; +using KernelTypes = ::testing::Types, + std::tuple>; TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); TYPED_TEST(TestGemmAdd, Test_BF16FP16) { this->Run(); } From e59d281b1081c88700de55ae670a051029be647c Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 13:57:55 +0000 Subject: [PATCH 46/62] removed the old wmma examples --- _deps/gtest-src | 1 + example/68_gemm_add/gemm_add_wmma_bf16.cpp | 72 ---------- example/68_gemm_add/gemm_add_wmma_fp16.cpp | 72 ---------- example/68_gemm_add/run_gem_add_example.inc | 143 -------------------- 4 files changed, 1 insertion(+), 287 deletions(-) create mode 160000 _deps/gtest-src delete mode 100644 example/68_gemm_add/gemm_add_wmma_bf16.cpp delete mode 100644 example/68_gemm_add/gemm_add_wmma_fp16.cpp delete mode 100644 example/68_gemm_add/run_gem_add_example.inc diff --git a/_deps/gtest-src b/_deps/gtest-src new file mode 160000 index 0000000000..f8d7d77c06 --- /dev/null +++ b/_deps/gtest-src @@ -0,0 +1 @@ +Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571 diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp deleted file mode 100644 index bf9fc119f7..0000000000 --- a/example/68_gemm_add/gemm_add_wmma_bf16.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = BF16; -using EDataType = BF16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = Add; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< - ALayout, - BLayout, - ck::Tuple, - ELayout, - ADataType, - BDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 2, // Prefetch stage - 128, // BlockSize - 128, // MPerBlock - 64, // NPerBlock - 64, // KPerBlock - 8, // K1 - 16, // MPerWmma - 16, // NPerWmma - 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, // C shuffle (M Repeat) Per store - 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, - 8>; - -// clang-format on - -#include "run_gem_add_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp deleted file mode 100644 index 3aa25bb471..0000000000 --- a/example/68_gemm_add/gemm_add_wmma_fp16.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = Add; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< - ALayout, - BLayout, - ck::Tuple, - ELayout, - ADataType, - BDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 2, // Prefetch stage - 128, // BlockSize - 128, // MPerBlock - 64, // NPerBlock - 64, // KPerBlock - 8, // K1 - 16, // MPerWmma - 16, // NPerWmma - 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 32, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, // C shuffle (M Repeat) Per store - 1, // C shuffle (N Repeat) Per store - S<1, 32, 1, 4>, - 8>; - -// clang-format on - -#include "run_gem_add_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/run_gem_add_example.inc b/example/68_gemm_add/run_gem_add_example.inc deleted file mode 100644 index 3a713a0c3d..0000000000 --- a/example/68_gemm_add/run_gem_add_example.inc +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once - -bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - using namespace ck::literals; - - auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(config.init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - - auto argument = - device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << device_op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(config.do_verification) - { - Tensor c_m_n({M, N}); - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); - } - - return true; -} - -bool run_gemm_add_example(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config); -} From 566e472f4cdf8ab1a3ea4863415d573b6d1c4669 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 14:50:25 +0000 Subject: [PATCH 47/62] Renamed the v3 instances --- example/68_gemm_add/CMakeLists.txt | 11 +++++++++-- ...mm_add_wmma_v3_bf16.cpp => gemm_add_wmma_bf16.cpp} | 0 ...mm_add_wmma_v3_fp16.cpp => gemm_add_wmma_fp16.cpp} | 0 ...mm_add_example_v3.inc => run_gemm_add_example.inc} | 0 4 files changed, 9 insertions(+), 2 deletions(-) rename example/68_gemm_add/{gemm_add_wmma_v3_bf16.cpp => gemm_add_wmma_bf16.cpp} (100%) rename example/68_gemm_add/{gemm_add_wmma_v3_fp16.cpp => gemm_add_wmma_fp16.cpp} (100%) rename example/68_gemm_add/{run_gemm_add_example_v3.inc => run_gemm_add_example.inc} (100%) diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index f64a291b97..af091d32e4 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,13 +1,20 @@ add_custom_target(example_gemm_add_xdl) add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16) + + add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16) add_custom_target(example_gemm_add_wmma) + add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp) +add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_bf16) + add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp) -add_example_executable(example_gemm_add_wmma_v3_fp16 gemm_add_wmma_v3_fp16.cpp) -add_example_executable(example_gemm_add_wmma_v3_bf16 gemm_add_wmma_v3_bf16.cpp) +add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_fp16) + diff --git a/example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp similarity index 100% rename from example/68_gemm_add/gemm_add_wmma_v3_bf16.cpp rename to example/68_gemm_add/gemm_add_wmma_bf16.cpp diff --git a/example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp similarity index 100% rename from example/68_gemm_add/gemm_add_wmma_v3_fp16.cpp rename to example/68_gemm_add/gemm_add_wmma_fp16.cpp diff --git a/example/68_gemm_add/run_gemm_add_example_v3.inc b/example/68_gemm_add/run_gemm_add_example.inc similarity index 100% rename from example/68_gemm_add/run_gemm_add_example_v3.inc rename to example/68_gemm_add/run_gemm_add_example.inc From 965501086c76b1bf01504f2c67041ca923e57134 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 14:53:47 +0000 Subject: [PATCH 48/62] Deleted the gtest file added by mistake. --- _deps/gtest-src | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _deps/gtest-src diff --git a/_deps/gtest-src b/_deps/gtest-src deleted file mode 160000 index f8d7d77c06..0000000000 --- a/_deps/gtest-src +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571 From 536f86661d6e119c247cf38445b9d49f8375459a Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 14:57:12 +0000 Subject: [PATCH 49/62] Updated thge profiler with wrapper --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index 5d79a98c11..dcefcee299 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,7 +93,7 @@ bool profile_gemm_add_relu_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto cde_element_op = CDEElementOp{}; - using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK< ALayout, BLayout, ck::Tuple, From 13efcc6fe1488f23ee7428df20c8aad2148f9f02 Mon Sep 17 00:00:00 2001 From: apoorva Date: Tue, 8 Jul 2025 18:30:01 +0000 Subject: [PATCH 50/62] Fixed test errors. --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index dcefcee299..65b1925f61 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -173,6 +173,7 @@ bool profile_gemm_add_relu_impl(int do_verification, StrideB, std::array{StrideD0}, StrideE, + 1, a_element_op, b_element_op, cde_element_op); From 55299c924e4459b99db1f3b3d68faa21b19c1d87 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 07:33:14 +0000 Subject: [PATCH 51/62] Fixed the review comments --- example/69_gemm_add_relu/common.hpp | 2 +- .../gemm_add_relu_xdl_bf16.cpp | 2 +- .../gemm_add_relu_xdl_fp16.cpp | 2 +- .../gpu/gemm_add_relu.hpp | 122 ++++++++++++++---- 4 files changed, 98 insertions(+), 30 deletions(-) diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp index 151653e515..311cbb2dfb 100644 --- a/example/69_gemm_add_relu/common.hpp +++ b/example/69_gemm_add_relu/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp index 824b1c2f10..6fcafb1cc0 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp index ef8c4cdcf8..6cd0ef4d41 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index c039f94021..7fed931d5b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -109,32 +109,9 @@ struct DeviceOperationInstanceFactory< { std::vector> op_ptrs; -#ifdef CK_USE_XDL -#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( - op_ptrs); - } - } -#endif - -#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( - op_ptrs); - } - } -#endif +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with AddRelu at the moment +#endif // CK_USE_XDL #elif defined(CK_USE_WMMA) @@ -169,6 +146,97 @@ struct DeviceOperationInstanceFactory< } }; +// GEMM + Add + Relu +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + + return op_ptrs; + } +}; } // namespace instance } // namespace device } // namespace tensor_operation From 32125077e702c9cc7038dcac4f8e251ddfe9b732 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 08:03:09 +0000 Subject: [PATCH 52/62] Fixed the if condition MACROS. --- .../ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp index 7fed931d5b..51023340fd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -113,7 +113,7 @@ struct DeviceOperationInstanceFactory< // No XDL instances for DeviceGemmMultipleDSplitK with AddRelu at the moment #endif // CK_USE_XDL -#elif defined(CK_USE_WMMA) +#if defined(CK_USE_WMMA) #if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && From 21cb98546cca10c80fea142544c3f9d0937614f0 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 08:22:52 +0000 Subject: [PATCH 53/62] REVERTED THE PROFILER CHANGES --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index 65b1925f61..8b0d4cd79d 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,7 +93,7 @@ bool profile_gemm_add_relu_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto cde_element_op = CDEElementOp{}; - using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK< + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple, @@ -173,7 +173,6 @@ bool profile_gemm_add_relu_impl(int do_verification, StrideB, std::array{StrideD0}, StrideE, - 1, a_element_op, b_element_op, cde_element_op); From e1374ea221b95f2ac61ee95feda5fad6896b3a27 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 08:25:30 +0000 Subject: [PATCH 54/62] Revert "REVERTED THE PROFILER CHANGES" This reverts commit 21cb98546cca10c80fea142544c3f9d0937614f0. --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index 8b0d4cd79d..65b1925f61 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,7 +93,7 @@ bool profile_gemm_add_relu_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto cde_element_op = CDEElementOp{}; - using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK< ALayout, BLayout, ck::Tuple, @@ -173,6 +173,7 @@ bool profile_gemm_add_relu_impl(int do_verification, StrideB, std::array{StrideD0}, StrideE, + 1, a_element_op, b_element_op, cde_element_op); From 9e3d87ea8a2c6072733caa28fc8ba1bec9d405a7 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 08:26:08 +0000 Subject: [PATCH 55/62] Revert "Fixed test errors." This reverts commit 13efcc6fe1488f23ee7428df20c8aad2148f9f02. --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index 65b1925f61..dcefcee299 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -173,7 +173,6 @@ bool profile_gemm_add_relu_impl(int do_verification, StrideB, std::array{StrideD0}, StrideE, - 1, a_element_op, b_element_op, cde_element_op); From ea133bf303430cd8e368ecb838011ae29759c6b1 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 08:53:36 +0000 Subject: [PATCH 56/62] Revert "Updated thge profiler with wrapper" This reverts commit 536f86661d6e119c247cf38445b9d49f8375459a. --- profiler/include/profiler/profile_gemm_add_relu_impl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index dcefcee299..5d79a98c11 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -93,7 +93,7 @@ bool profile_gemm_add_relu_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto cde_element_op = CDEElementOp{}; - using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK< + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< ALayout, BLayout, ck::Tuple, From 76f4bb0e60ab8eca646bbec90364d94cd958c91c Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 09:22:02 +0000 Subject: [PATCH 57/62] Added missing wrapper instances --- .../gpu/gemm_add.hpp | 198 ++++++++++++------ 1 file changed, 134 insertions(+), 64 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 4b7bdeb5af..3832a9f6d4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + #ifdef CK_USE_XDL void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( std::vector>>&); -void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( std::vector>>&); + std::vector>>&); void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( - std::vector>>&); + std::vector>>&); #endif -// GEMM + Add + + +// GEMM + Add template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMultipleD, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - PassThrough, - PassThrough, - Add>> + ck::tensor_operation::device::DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleDSplitK with Add at the moment +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + +// GEMM + Add +// DeviceGemmMultipleD specialization +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>> { using DeviceOp = DeviceGemmMultipleD> op_ptrs; + #ifdef CK_USE_XDL #if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && @@ -130,40 +207,33 @@ struct DeviceOperationInstanceFactory< } } #endif -#elif defined(CK_USE_WMMA) -// TODO: -// here for WMMA, currently BDataType and ADataType must be the same -#if defined(CK_ENABLE_FP16) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs); - } - } -#endif -#if defined(CK_ENABLE_BF16) - // TODO: - // here for WMMA, currently BDataType and ADataType must be the same - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) +#endif // CK_USE_XDL + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleDSplitK instances + using Wrapper = DeviceGemmMultipleDSplitKWrapper, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) { - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances( - op_ptrs); - } + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); } -#endif -#endif +#endif // CK_USE_WMMA + return op_ptrs; } }; - } // namespace instance } // namespace device } // namespace tensor_operation From 2738ca5047cc1059f19d5f8c25132539ccf74a41 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 09:25:09 +0000 Subject: [PATCH 58/62] Updated copyrights. --- example/68_gemm_add/gemm_add_wmma_bf16.cpp | 4 ++-- example/68_gemm_add/gemm_add_wmma_fp16.cpp | 4 ++-- example/68_gemm_add/gemm_add_xdl_bf16.cpp | 2 +- example/68_gemm_add/gemm_add_xdl_fp16.cpp | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/68_gemm_add/gemm_add_wmma_bf16.cpp b/example/68_gemm_add/gemm_add_wmma_bf16.cpp index 2a3641defc..ba8b4f1f76 100644 --- a/example/68_gemm_add/gemm_add_wmma_bf16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -73,6 +73,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ // clang-format on -#include "run_gemm_add_example_v3.inc" +#include "run_gemm_add_example.inc" int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_wmma_fp16.cpp b/example/68_gemm_add/gemm_add_wmma_fp16.cpp index c98fc4b39e..9fc366b298 100644 --- a/example/68_gemm_add/gemm_add_wmma_fp16.cpp +++ b/example/68_gemm_add/gemm_add_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -71,6 +71,6 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_ // clang-format on -#include "run_gemm_add_example_v3.inc" +#include "run_gemm_add_example.inc" int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/68_gemm_add/gemm_add_xdl_bf16.cpp b/example/68_gemm_add/gemm_add_xdl_bf16.cpp index f5bfc14ebc..5d2cab49d2 100644 --- a/example/68_gemm_add/gemm_add_xdl_bf16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/68_gemm_add/gemm_add_xdl_fp16.cpp b/example/68_gemm_add/gemm_add_xdl_fp16.cpp index fd86738260..1338caef8b 100644 --- a/example/68_gemm_add/gemm_add_xdl_fp16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" From e6ea4aaf6d40f80337fdf2afc1906b547c54d55d Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 10:55:21 +0000 Subject: [PATCH 59/62] Fixed typo. --- .../ck/library/tensor_operation_instance/gpu/gemm_add.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp index 3832a9f6d4..bc012fa675 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -30,7 +30,7 @@ void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( PassThrough, Add>>>&); -void add_device_gemm_add_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( +void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( std::vector Date: Wed, 9 Jul 2025 11:27:26 +0000 Subject: [PATCH 60/62] Fixed copyrights. --- test/gemm_add/test_gemm_add_relu_wmma.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gemm_add/test_gemm_add_relu_wmma.cpp b/test/gemm_add/test_gemm_add_relu_wmma.cpp index 8ff2f4217b..76c66a11b1 100644 --- a/test/gemm_add/test_gemm_add_relu_wmma.cpp +++ b/test/gemm_add/test_gemm_add_relu_wmma.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" From 8e917555457107e49ab28b011f74ffa855dc7fa5 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 11:48:33 +0000 Subject: [PATCH 61/62] Updated copyrights. --- example/68_gemm_add/common.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/68_gemm_add/common.hpp b/example/68_gemm_add/common.hpp index eab37e4132..38e77a160f 100644 --- a/example/68_gemm_add/common.hpp +++ b/example/68_gemm_add/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once From aea158f79718f4f12dd2d5cb037e8e955e5a7fc8 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 9 Jul 2025 11:51:49 +0000 Subject: [PATCH 62/62] updated copyrights. --- test/gemm_add/test_gemm_add_wmma.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gemm_add/test_gemm_add_wmma.cpp b/test/gemm_add/test_gemm_add_wmma.cpp index 5fc6738ca9..ae08d50fcc 100644 --- a/test/gemm_add/test_gemm_add_wmma.cpp +++ b/test/gemm_add/test_gemm_add_wmma.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp"