mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch '52-implement-multipled-in-gemm-universal' into 'feature/multiple-d-gemms'
DeviceGemmMultipleD_Wmma_CShuffleV3 See merge request amd/ai/composable_kernel!21
This commit is contained in:
@@ -30,3 +30,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)
|
||||
|
||||
267
example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp
Normal file
267
example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp
Normal file
@@ -0,0 +1,267 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F16;
|
||||
using B0DataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
struct AddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
|
||||
ck::half_t& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
const float x0_f = c + d0 + d1;
|
||||
|
||||
e = ck::type_convert<ck::half_t>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S<C, D..>| | |
|
||||
< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideD = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
|
||||
d1_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
|
||||
StrideE,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
|
||||
sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B0DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -184,7 +184,6 @@ int main(int argc, char* argv[])
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -220,11 +219,12 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
|
||||
sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -233,8 +233,6 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#ifndef __HIPCC_RTC__
|
||||
@@ -149,6 +149,107 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
#endif
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is
|
||||
/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter
|
||||
/// KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
1, // KBatch
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,613 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM operation implements the following mathematical equation:
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
///
|
||||
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
|
||||
/// to split the dot product accumulated over the K dimension into multiple working groups.
|
||||
/// The partial products of different workgroups are then reduced using the AtomicAdd
|
||||
/// operation.
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
/// @tparam NPerBlock The input/output data tile size in the N dimension.
|
||||
/// @tparam KPerBlock The input data tile size in the K dimension.
|
||||
/// @tparam AK1 The vector load size from global memory for A tensor.
|
||||
/// @tparam BK1 The vector load size from global memory for B tensor.
|
||||
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
|
||||
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
|
||||
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
|
||||
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question, "How many threads can be
|
||||
/// arranged on each input data axis?"
|
||||
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
|
||||
/// data. Can be interpreted as the answer
|
||||
/// to the question: "How many threads to
|
||||
/// arrange on each input data axis?"
|
||||
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
|
||||
/// the input tensor dimension. Can be interpreted
|
||||
/// as the answer to the question: "In which
|
||||
/// order to spread threads through tensor axes?".
|
||||
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
|
||||
/// interpreted as the answer to the question "Which dimension
|
||||
/// to read first? And which next?" etc.
|
||||
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
|
||||
/// access - the one with contiguous memory.
|
||||
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
|
||||
/// elements accessed per thread per instruction.
|
||||
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
|
||||
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
|
||||
/// universal GEMM there's no need for padding.
|
||||
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in M dimension.
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
|
||||
/// instructions.
|
||||
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory. Currently not supported!
|
||||
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
: public DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
/// @brief Helper structure responsible for kernel invocation.
|
||||
///
|
||||
/// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU
|
||||
/// kernel function. It usually determines the launched grid size prepares kernel
|
||||
/// arguments as well as perform specific kernel configuration selection based on
|
||||
/// runtime arguments.
|
||||
///
|
||||
/// @note If appropriately configured it may measure kernel execution time.
|
||||
///
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
std::array<std::size_t, NumDTensor> size_ds_buffers;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
size_ds_buffers[i] =
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
size_a_buffer,
|
||||
size_b_buffer,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Implement
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.KBatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideE,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<EDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0];
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
str << std::string(DLayout::name)[0];
|
||||
});
|
||||
str << std::string(ELayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma << "x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat << "x" << NRepeat << ", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -176,15 +176,16 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
Tuple<>, // DsDataType
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
@@ -219,7 +220,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
@@ -294,7 +295,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -312,7 +313,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
@@ -468,11 +469,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -488,20 +503,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
KBatch);
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -19,7 +19,7 @@ namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
@@ -31,22 +31,26 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
@@ -59,9 +63,10 @@ __global__ void
|
||||
///
|
||||
/// @par Overview
|
||||
/// This GEMM kernel is carrying out following mathematical equation:
|
||||
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
|
||||
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
|
||||
/// elementwise operations that could be applied on each tensor respectively.
|
||||
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
|
||||
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
|
||||
// operations that could be applied on each tensor respectively. The CDE_op is an
|
||||
// elementwise operation applied to the C and all D tensors.
|
||||
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
|
||||
/// scenarios. That's why it's called \"universal\". It's universal through it's design
|
||||
/// and versatilty.
|
||||
@@ -73,18 +78,20 @@ __global__ void
|
||||
///
|
||||
/// @tparam ALayout A tensor data layout.
|
||||
/// @tparam BLayout B tensor data layout.
|
||||
/// @tparam CLayout C tensor data layout.
|
||||
/// @tparam DsLayout D tensors data layouts.
|
||||
/// @tparam ELayout E tensor data layout.
|
||||
/// @tparam ADataType A tensor data type.
|
||||
/// @tparam BDataType B tensor data type.
|
||||
/// @tparam AccDataType The accumulation data type related to the hardware
|
||||
/// matrix-multiplication instruction.
|
||||
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
|
||||
/// LDS memory during \"CShuffle\" data layout optimization.
|
||||
/// @tparam CDataType C tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
|
||||
/// (after GEMM).
|
||||
/// @tparam DsDataType D tensors data types.
|
||||
/// @tparam EDataType E tensor data type.
|
||||
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
|
||||
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
|
||||
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
|
||||
/// GEMM) and D input tensors.
|
||||
/// @tparam GemmSpec Determines used "padding" version.
|
||||
/// @tparam BlockSize The number of threads within workgroup.
|
||||
/// @tparam MPerBlock The input/output data tile size in the M dimension.
|
||||
@@ -142,11 +149,12 @@ __global__ void
|
||||
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
|
||||
/// results to process per wave per iteration of CShuffle
|
||||
/// in N dimension.
|
||||
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
|
||||
/// thread distribution used for storing data into output
|
||||
/// tensor across output data layout dimensions.
|
||||
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
|
||||
/// Used when storing data to output tensor.
|
||||
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
|
||||
/// Used when loading data from D tensors and storing data
|
||||
/// to output tensor.
|
||||
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
|
||||
/// intrawave).
|
||||
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
|
||||
@@ -160,15 +168,17 @@ __global__ void
|
||||
/// in global memory (pre-shuffled).
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename CDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
@@ -198,8 +208,8 @@ template <typename ALayout,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
@@ -217,6 +227,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
static constexpr auto EShuffleBlockTransferScalarPerVector =
|
||||
CDEShuffleBlockTransferScalarPerVectors{}[I0];
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
|
||||
@@ -530,17 +543,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
template <typename DELayout>
|
||||
__host__ __device__ static auto
|
||||
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
|
||||
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
|
||||
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -593,6 +607,44 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto MakeDsGridPointer()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
return static_cast<const DDataType*>(nullptr);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n[i], MBlock, NBlock);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem(index_t M_,
|
||||
@@ -600,14 +652,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideC{StrideC_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
@@ -627,8 +681,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SC:" << StrideC << ", "
|
||||
<< "SB:" << StrideB << ", ";
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
std::cout << "SDs: { ";
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", "
|
||||
<< "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", "
|
||||
<< "KRead:" << KRead << ", "
|
||||
@@ -644,7 +706,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
index_t StrideC;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
@@ -661,21 +724,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
__host__ Argument(const ADataType* p_a_grid_,
|
||||
const BDataType* p_b_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
EDataType* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CDEElementwiseOperation cde_element_op_,
|
||||
bool is_reduce_ = false)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
cde_element_op{cde_element_op_},
|
||||
is_reduce(is_reduce_)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
__host__ __device__ inline bool IsReduceAdd() const
|
||||
@@ -690,42 +767,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
const ADataType* p_a_grid;
|
||||
const BDataType* p_b_grid;
|
||||
CDataType* p_c_grid;
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
};
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -736,7 +820,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1143,7 +1227,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * KPerBlock! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
@@ -1219,56 +1303,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
"EShuffleBlockTransferScalarPerVector ("
|
||||
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
|
||||
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
|
||||
{
|
||||
if(!karg.IsReduceAdd())
|
||||
if(karg.IsAtomicAdd() && karg.KBatch > 1)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
|
||||
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1301,18 +1380,18 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
|
||||
}
|
||||
|
||||
template <typename CGridDesc>
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
template <typename DEGridDesc>
|
||||
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
{
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
de_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return de_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
@@ -1322,30 +1401,40 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
const AElementwiseOperation a_element_op{};
|
||||
const BElementwiseOperation b_element_op{};
|
||||
const CElementwiseOperation c_element_op{};
|
||||
const auto ds_grid_buf = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_ds_grid[i],
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
@@ -1355,8 +1444,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -1483,7 +1572,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
c_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
// Epilogue: shuffle C for better memory access pattern, apply elementwise operation to
|
||||
// C and Ds, write out result E to global memory
|
||||
{
|
||||
// C mapping in single thread.
|
||||
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
|
||||
@@ -1601,31 +1691,63 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
m_thread_data_on_block_idx[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
// shuffle: blockwise copy C from LDS to global
|
||||
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
CElementwiseOperation, // ElementwiseOperation,
|
||||
CGlobalMemoryDataOperation, // DstInMemOp,
|
||||
// tuple of reference to C/Ds tensor descriptors
|
||||
const auto c_ds_desc_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of reference to C/Ds tensor buffers
|
||||
const auto c_ds_buf_refs = concat_tuple_of_reference(
|
||||
tie(c_shuffle_block_buf),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& // return type should be reference
|
||||
{ return ds_grid_buf[i]; },
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// tuple of starting index of C/Ds blockwise copy
|
||||
const auto idx_c_ds_block_begin = container_concat(
|
||||
make_tuple(make_multi_index(0, 0, 0, 0)),
|
||||
generate_tuple(
|
||||
[&](auto) {
|
||||
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
|
||||
},
|
||||
Number<NumDTensor>{}));
|
||||
|
||||
// blockwise copy which loads C from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
|
||||
ThisThreadBlock, // ThreadGroup
|
||||
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
|
||||
Tuple<EDataType>,
|
||||
decltype(c_ds_desc_refs),
|
||||
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
|
||||
CDEElementwiseOperation, // ElementwiseOperation,
|
||||
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
|
||||
Sequence<1,
|
||||
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
|
||||
1,
|
||||
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
|
||||
CShuffleDataType, // typename SrcData,
|
||||
CDataType, // typename DstData,
|
||||
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
|
||||
3, // index_t VectorDim,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
|
||||
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
false> // bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
make_multi_index(0, 0, 0, 0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
|
||||
c_element_op};
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
|
||||
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
|
||||
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
|
||||
3, // SrcVectorDim,
|
||||
3, // DstVectorDim,
|
||||
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
|
||||
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
|
||||
sequence_merge_t<
|
||||
Sequence<true>,
|
||||
uniform_sequence_gen_t<NumDTensor,
|
||||
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
|
||||
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
|
||||
{c_ds_desc_refs,
|
||||
idx_c_ds_block_begin,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
|
||||
cde_element_op};
|
||||
|
||||
// space filling curve for local reg & global memory
|
||||
// space filling curve for threadwise C in VGPR
|
||||
@@ -1641,7 +1763,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
MAccVgprs>>{};
|
||||
|
||||
// space filling curve for shuffled blockwise C in global mem
|
||||
constexpr auto sfc_c_global =
|
||||
constexpr auto sfc_cde_global =
|
||||
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
|
||||
Sequence<0, 2, 1, 3>,
|
||||
Sequence<1,
|
||||
@@ -1651,7 +1773,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
|
||||
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto access_id) {
|
||||
// make sure it's safe to write to LDS
|
||||
@@ -1668,57 +1790,78 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
// make sure it's safe to read from LDS
|
||||
block_sync_lds();
|
||||
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global.Run(
|
||||
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
|
||||
c_shuffle_block_buf,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
// each block loads its C data from LDS, D from global, applies elementwise
|
||||
// operation and stores result E to global
|
||||
cde_shuffle_block_copy_lds_and_global.Run(
|
||||
c_ds_desc_refs,
|
||||
c_ds_buf_refs,
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
tie(e_grid_buf));
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
|
||||
// move on Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
|
||||
c_ds_desc_refs, i + I1, cde_global_step);
|
||||
});
|
||||
|
||||
// move on C
|
||||
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
// move on E
|
||||
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
|
||||
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum = TailNumber::Odd>
|
||||
__device__ static void Run(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
Run<decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -93,8 +94,90 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleDSplitK specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add other types and layouts
|
||||
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -132,7 +215,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -143,7 +227,7 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8
|
||||
|
||||
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
@@ -156,8 +240,9 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -186,6 +271,29 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_fastgelu_instance
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
// e = elementwise((a * b), d0)
|
||||
// elementwise(c, d0) = fastgelu(c + d0)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n]
|
||||
|
||||
using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmDefault, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_generic_instances{});
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_add_fastgelu_wmma_universal_f16_f16_f16_f16_mk_nk_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -45,7 +45,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_streamk.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_fastgelu.cpp)
|
||||
@@ -86,11 +85,14 @@ if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFIN
|
||||
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
|
||||
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_fastgelu.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
@@ -152,7 +154,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_silu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_relu_add_layernorm_instance)
|
||||
@@ -202,6 +203,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp)
|
||||
add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
|
||||
target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp)
|
||||
add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
|
||||
target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
|
||||
endif()
|
||||
|
||||
40
test/gemm_add/test_gemm_add_fastgelu_wmma.cpp
Normal file
40
test/gemm_add/test_gemm_add_fastgelu_wmma.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
constexpr static auto ProfileGemmAddFastgeluImpl =
|
||||
ck::profiler::profile_gemm_add_fastgelu_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; }
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, Row, Col, Row, Row>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
|
||||
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); }
|
||||
@@ -1,13 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddFastgelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_relu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddRelu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddRelu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_silu_impl.hpp"
|
||||
#include "test_gemm_add_xdl.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAddSilu : public TestGemmAdd<Tuple>
|
||||
class TestGemmAddSilu : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
|
||||
@@ -1,22 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using I8 = int8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmAdd : public ::testing::Test
|
||||
class TestGemmAdd : public TestGemmD0Common<Tuple>
|
||||
{
|
||||
protected:
|
||||
private:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
@@ -37,32 +30,7 @@ class TestGemmAdd : public ::testing::Test
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; }
|
||||
|
||||
void Run()
|
||||
{
|
||||
std::vector<std::vector<ck::index_t>> lengths = {
|
||||
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
|
||||
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
decltype(ProfileGemmAddImpl) GetImpl() override { return ProfileGemmAddImpl; }
|
||||
};
|
||||
|
||||
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
|
||||
|
||||
66
test/gemm_add/test_gemm_common.hpp
Normal file
66
test/gemm_add/test_gemm_common.hpp
Normal file
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_add_impl.hpp"
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using I8 = int8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmD0Common : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using BDataType = std::tuple_element_t<1, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<2, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<3, Tuple>;
|
||||
using EDataType = std::tuple_element_t<4, Tuple>;
|
||||
using ALayout = std::tuple_element_t<5, Tuple>;
|
||||
using BLayout = std::tuple_element_t<6, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<7, Tuple>;
|
||||
using ELayout = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
D0DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
ELayout>;
|
||||
|
||||
virtual decltype(ProfileGemmAddImpl) GetImpl() = 0;
|
||||
|
||||
void Run()
|
||||
{
|
||||
std::vector<std::vector<ck::index_t>> lengths = {
|
||||
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
|
||||
|
||||
bool all_success = true;
|
||||
|
||||
for(auto length : lengths)
|
||||
{
|
||||
int M = length[0];
|
||||
int N = length[1];
|
||||
int K = length[2];
|
||||
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user