Merge branch '52-implement-multipled-in-gemm-universal' into 'feature/multiple-d-gemms'

DeviceGemmMultipleD_Wmma_CShuffleV3

See merge request amd/ai/composable_kernel!21
This commit is contained in:
Anton Gorenko
2025-06-04 16:18:14 +05:00
18 changed files with 1638 additions and 228 deletions

View File

@@ -30,3 +30,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)

View File

@@ -0,0 +1,267 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F16;
using B0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct AddAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c + d0 + d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3
// clang-format off
//#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S<C, D..>| | |
< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = K;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -184,7 +184,6 @@ int main(int argc, char* argv[])
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -220,11 +219,12 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -233,8 +233,6 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
@@ -149,6 +149,107 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is
/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter
/// KBatch which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
1, // KBatch
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,613 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension into multiple working groups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultipleD_Wmma_CShuffleV3
: public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
/// kernel function. It usually determines the launched grid size prepares kernel
/// arguments as well as perform specific kernel configuration selection based on
/// runtime arguments.
///
/// @note If appropriately configured it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, NumDTensor> size_ds_buffers;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
size_a_buffer,
size_b_buffer,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
str << std::string(DLayout::name)[0];
});
str << std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -176,15 +176,16 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -219,7 +220,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -294,7 +295,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -312,7 +313,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -468,11 +469,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -488,20 +503,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch);
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -19,7 +19,7 @@ namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
@@ -31,22 +31,26 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
#if defined(__gfx11__)
}
#endif
@@ -59,9 +63,10 @@ __global__ void
///
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations that could be applied on each tensor respectively.
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -73,18 +78,20 @@ __global__ void
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
@@ -142,11 +149,12 @@ __global__ void
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
@@ -160,15 +168,17 @@ __global__ void
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -198,8 +208,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -217,6 +227,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
@@ -530,17 +543,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
}
template <typename DELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
}
}();
@@ -593,6 +607,44 @@ struct GridwiseGemm_wmma_cshuffle_v3
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
struct Problem
{
__host__ Problem(index_t M_,
@@ -600,14 +652,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -627,8 +681,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
@@ -644,7 +706,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -661,21 +724,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -690,42 +767,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -736,7 +820,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -1143,7 +1227,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
std::cout << "Arg K value is not a multiple of K_Batch * KPerBlock! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
@@ -1219,56 +1303,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
is_same<remove_cvref_t<EDataType>, float>::value ||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
if(karg.IsAtomicAdd() && karg.KBatch > 1)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
if(karg.KBatch > 1)
{
return false;
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1301,18 +1380,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
@@ -1322,30 +1401,40 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -1355,8 +1444,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -1483,7 +1572,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
// Epilogue: shuffle C for better memory access pattern, apply elementwise operation to
// C and Ds, write out result E to global memory
{
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
@@ -1601,31 +1691,63 @@ struct GridwiseGemm_wmma_cshuffle_v3
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// blockwise copy which loads C from LDS, D from global, applies elementwise
// operation and stores result E to global
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
@@ -1641,7 +1763,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_cde_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1651,7 +1773,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -1668,57 +1790,78 @@ struct GridwiseGemm_wmma_cshuffle_v3
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock);
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op);
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -93,8 +94,90 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_in
PassThrough,
PassThrough,
AddFastGelu>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + FastGelu
// DeviceGemmMultipleDSplitK specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
}
// TODO: Add other types and layouts
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + Add + FastGelu
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
@@ -132,7 +215,8 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -143,7 +227,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
@@ -156,8 +240,9 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -186,6 +271,29 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,9 +1,11 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_fastgelu_instance
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmDefault, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances{});
add_device_operation_instances(
instances, device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -45,7 +45,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
@@ -86,11 +85,14 @@ if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFIN
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
endif()
endif()
if(DL_KERNELS)
@@ -152,7 +154,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
@@ -202,6 +203,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
endif()
endif()
if(DL_KERNELS)

View File

@@ -1,19 +1,24 @@
add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp)
add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
endif()
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp)
add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
endif()
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp)
add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
endif()
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp)
add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
endif()
add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
endif()

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
#include "test_gemm_common.hpp"
template <typename Tuple>
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddFastgeluImpl =
ck::profiler::profile_gemm_add_fastgelu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; }
};
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Col, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); }

View File

@@ -1,13 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
#include "test_gemm_add_xdl.hpp"
#include "test_gemm_common.hpp"
template <typename Tuple>
class TestGemmAddFastgelu : public TestGemmAdd<Tuple>
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;

View File

@@ -1,13 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_relu_impl.hpp"
#include "test_gemm_add_xdl.hpp"
#include "test_gemm_common.hpp"
template <typename Tuple>
class TestGemmAddRelu : public TestGemmAdd<Tuple>
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;

View File

@@ -1,13 +1,13 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_silu_impl.hpp"
#include "test_gemm_add_xdl.hpp"
#include "test_gemm_common.hpp"
template <typename Tuple>
class TestGemmAddSilu : public TestGemmAdd<Tuple>
class TestGemmAddSilu : public TestGemmD0Common<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;

View File

@@ -1,22 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
#include "test_gemm_common.hpp"
template <typename Tuple>
class TestGemmAdd : public ::testing::Test
class TestGemmAdd : public TestGemmD0Common<Tuple>
{
protected:
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
@@ -37,32 +30,7 @@ class TestGemmAdd : public ::testing::Test
D0Layout,
ELayout>;
virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; }
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
all_success =
all_success &
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
}
EXPECT_TRUE(all_success);
}
decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; }
};
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,

View File

@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <typename Tuple>
class TestGemmD0Common : public ::testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
virtual decltype(ProfileGemmAddImpl) GetImpl() = 0;
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
all_success =
all_success &
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
}
EXPECT_TRUE(all_success);
}
};