Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
Gemm + bias + relu + add + layernorm (#272)
* Copy "gemm reduce" to "gemm bias add reduce"
* Implement gemm bias add reduction
* Fix compiler error due to merge from develop
* Add tensor operation for gemm + bias + add + reduce
* Add gemm_bais_add_reduce to ckProfiler
* Add c1 functor
* Refine type
* Use reduceAccDataType instead of explicitly float
* Change to use check_err()
* Do relu in float32 instead of bhalf_t. Because bhalf_t is unsigned
* Refactor relu. using type_trait instead of overloading
* Rename DxsReduceAccElementwiseOperation to DxsReduceAccElementwiseOperation
* Fix denominator
* Refine nameing
* Fix denominator in host
* Remove useless include header
* Use AccDataType
* Fix static_cast order
* Refine type
* [What] Remove tuple type in the base class
[Why] External api depend on base class. if base class has relationship with type, we will need many class for different type
[ROCm/composable_kernel commit: 6eb5549923]
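One of the relu bullets deserves a note: ck::bhalf_t is stored as an unsigned integer, so comparing a bhalf_t against zero compares raw bits and is meaningless — the activation has to be evaluated in float. Below is a minimal sketch of the type_trait-style dispatch the commit describes; ReluSketch is an illustrative name, not the committed functor, while ck::bhalf_t and ck::type_convert are the CK facilities it relies on:

#include <type_traits>

template <typename T>
struct ReluSketch
{
    void operator()(T& y, const T& x) const
    {
        if constexpr(std::is_same<T, ck::bhalf_t>::value)
        {
            // bhalf_t carries no sign in its integer representation:
            // convert to float, clamp there, then convert back
            const float xf = ck::type_convert<float>(x);
            y = ck::type_convert<ck::bhalf_t>(xf > 0.f ? xf : 0.f);
        }
        else
        {
            y = x > static_cast<T>(0) ? x : static_cast<T>(0);
        }
    }
};

A single template with a compile-time branch replaces the per-type overload set, which is what "type_trait instead of overloading" refers to.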
@@ -51,8 +51,8 @@ using UnaryDivElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
-using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
+using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;

using DGlobalMemOp =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
@@ -67,7 +67,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -204,8 +204,8 @@ int main(int argc, char* argv[])
    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));

-    auto dxs_in_element_op = DxsInElementOp{};
-    auto dxs_out_element_op = DxsOutElementOp{M, M};
+    auto dxs_in_element_op = DxsInElementOps{};
+    auto dxs_out_element_op = DxsOutElementOps{N, N};

    // do GEMM
    auto gemm = DeviceGemmReduceInstance{};
@@ -261,14 +261,15 @@ int main(int argc, char* argv[])

    for(int m = 0; m < M; ++m)
    {
-        float d0_acc = d0_reduce_op.GetIdentityValue();
-        float d1_acc = d1_reduce_op.GetIdentityValue();
+        ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
+        ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();

        for(int n = 0; n < N; ++n)
        {
-            float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
-            float d0_val = 0;
-            float d1_val = 0;
+            ReduceAccDataType c_val =
+                ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
+            ReduceAccDataType d0_val = 0;
+            ReduceAccDataType d1_val = 0;

            dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
            dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);

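The verification hunk above also reflects the two "Fix denominator" bullets: the per-row accumulators now use the ReduceAccDataType alias instead of hard-coded float, and the averaging divides by the reduced row length N rather than M. Per row m the reference effectively computes

    mean[m]        = (1 / N) * sum_n c[m][n]
    mean_square[m] = (1 / N) * sum_n c[m][n]^2

which matches what the device-side DxsOutElementOps{N, N} produces.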
@@ -47,8 +47,8 @@ using UnaryIdenticElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
-using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
+using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;

using DGlobalMemOp =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
@@ -206,8 +206,8 @@ int main(int argc, char* argv[])
                                          a_element_op,
                                          b_element_op,
                                          c_element_op,
-                                         DxsInElementOp{},
-                                         DxsOutElementOp{},
+                                         DxsInElementOps{},
+                                         DxsOutElementOps{},
                                          BatchCount);

    if(!batched_gemm.IsSupportedArgument(argument))

@@ -1 +1,2 @@
+add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
 add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)

@@ -0,0 +1,424 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_5ary_elementwise.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using C0DataType = F32;
using C1DataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu;
using C1ElementOp = PassThrough;
using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;

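// The two ReduceSumOps above accumulate sum(x) and sum(x^2) for every row of the
// GEMM output. The elementwise ops below feed the pair (identity into the first
// reduction, square into the second) and post-process the accumulated sums:
// UnaryDivElementOp is UnaryIdentic with its divide template flag set and is
// constructed with N in main(), so the two D outputs come out as E[x] and E[x^2].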
using UnaryIdenticElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnaryDivElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;

using DxsGlobalMemOp =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmSpecialization =
    ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                        BDataType,
                                                                        CDataType,
                                                                        GemmAccDataType,
                                                                        AElementOp,
                                                                        BElementOp,
                                                                        PassThrough>;

using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;

// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta, F:y
using DeviceNormalizeInstance =
    ck::tensor_operation::device::Device5AryElementwise<CDataType,
                                                        DDataType,
                                                        DDataType,
                                                        GammaDataType,
                                                        BetaDataType,
                                                        LayerNormOutDataType,
                                                        NormalizeComputeDataType,
                                                        NormalizeFunctor,
                                                        2,
                                                        8,
                                                        8,  // scalarPerVector: gemm_out
                                                        1,  // scalarPerVector: reduce_mean
                                                        1,  // scalarPerVector: reduce_mean_square
                                                        8,  // scalarPerVector: Gamma
                                                        8,  // scalarPerVector: Beta
                                                        8>; // scalarPerVector: LayerNorm_out

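// Device5AryElementwise applies NormalizeFunctor pointwise over the five inputs
// listed in the A..F legend above; per element this amounts to
//   y = gamma * (x - E[x]) / sqrt(E[x^2] - E[x]^2 + eps) + beta,
// with the variance recovered from the two reduction results. The exact epsilon
// handling is owned by NormalizeFunctor.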
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
    return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                std::vector<std::size_t>({stride}));
};

auto f_host_tensor_descriptor2d =
    [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                        std::vector<std::size_t>({1, stride}));
        }
    };

template <typename CDataType,
          typename DDataType,
          typename AccDataType,
          typename C0DataType,
          typename C1DataType,
          typename A_functor,
          typename B_functor,
          typename C_functor,
          typename C1_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
                         const Tensor<ADataType>& a_m_k,
                         const Tensor<ADataType>& b_k_n,
                         const Tensor<C0DataType>& bias_n,
                         const Tensor<C1DataType>& c1_m_n,
                         const Tensor<GammaDataType>& gamma_n,
                         const Tensor<GammaDataType>& beta_n,
                         A_functor a_element_op,
                         B_functor b_element_op,
                         C_functor c_element_op,
                         C1_functor c1_element_op,
                         int M,
                         int N)
{
    int StrideC = N;
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
    auto averageOpInst = UnaryDivElementOp{N};

    auto ref_gemm = ReferenceGemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();

    auto ref_argument =
        ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

    ref_invoker.Run(ref_argument);

    // c = activation(c + bias) + c1_functor(c1)
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccDataType acc =
                static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));

            AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));

            c_element_op(acc, acc);
            c1_element_op(c1, c1);
            acc += c1;
            c_m_n(m, n) = static_cast<CDataType>(acc);
        }

    // reduce_mean and reduce_square_mean
    auto reduceSumOpInst = ReduceSumOp{};
    for(int m = 0; m < M; ++m)
    {
        AccDataType mean_acc = reduceSumOpInst.GetIdentityValue();
        AccDataType square_mean_acc = reduceSumOpInst.GetIdentityValue();

        for(int n = 0; n < N; ++n)
        {
            AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n));
            AccDataType square_c_val = 0;
            UnarySquareElementOp{}(square_c_val, c_val);

            reduceSumOpInst(mean_acc, c_val);
            reduceSumOpInst(square_mean_acc, square_c_val);
        }

        averageOpInst(mean_acc, mean_acc);
        averageOpInst(square_mean_acc, square_mean_acc);
        mean_m(m) = ck::type_convert<DDataType>(mean_acc);
        meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
    }

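    // LayerNorm: NormalizeFunctor consumes the per-row statistics computed above.
    // No explicit variance is formed here; the functor receives E[x] and E[x^2]
    // and is expected to use var = E[x^2] - E[x]^2 internally.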
    // LayerNorm
    auto layerNormInst = NormalizeFunctor{};
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            AccDataType out_acc = 0;
            layerNormInst(out_acc, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
            out_m_n(m, n) = static_cast<DDataType>(out_acc);
        }
    }
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename C0DataType,
          typename C1DataType,
          typename DDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
{
    std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
                                sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
                                sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
                                sizeof(DDataType) * M;

    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;

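    // Invoker times are reported in ms, so Gflop / ms below reads out directly as
    // Tflop/s, and MB / ms as GB/s.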
    float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
    float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
    float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time;

    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;

    std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
              << " GB/s, " << std::endl;
}

int main()
{
    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 1024;

    ck::index_t StrideA = 1024;
    ck::index_t StrideB = 1024;
    ck::index_t StrideC = 1024;
    ck::index_t StrideC1 = 1024;

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<LayerNormOutDataType> layerNorm_m_n(
        f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
    bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
    c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});

    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
    DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
    DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
    DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
    DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
                                          reduceMeanSquare_m.mDesc.GetElementSpace());
    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
    DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
                                   layerNorm_m_n.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    bias_device_buf.ToDevice(bias_n.mData.data());
    c1_device_buf.ToDevice(c1_m_n.mData.data());
    gamma_device_buf.ToDevice(gamma_n.mData.data());
    beta_device_buf.ToDevice(beta_n.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};
    auto c1_element_op = C1ElementOp{};
    auto dxs_global =
        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));

    auto dxs_in_element_op = DxsInElementOps{};
    auto dxs_out_element_op = DxsOutElementOps{N, N};

    // Prepare GEMM, reduce_mean, reduce_mean_square
    auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
    auto gemmReduce_invoker = gemmReduce.MakeInvoker();
    auto gemmReduce_argument =
        gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
                                static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
                                dxs_global,
                                M,
                                N,
                                K,
                                StrideA,
                                StrideB,
                                StrideC,
                                StrideC1,
                                a_element_op,
                                b_element_op,
                                c_element_op,
                                c1_element_op,
                                dxs_in_element_op,
                                dxs_out_element_op);

    if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    reduceMean_device_buf.SetZero();
    reduceMeanSquare_device_buf.SetZero();

    // Prepare LayerNorm
    auto normalize = DeviceNormalizeInstance{};
    auto normalize_invoker = normalize.MakeInvoker();
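    // Each brace pair below is a {stride_m, stride_n} descriptor for one tensor:
    // a zero stride broadcasts, so the per-row mean / mean_square ({1, 0}) repeat
    // along N, while gamma / beta ({0, 1}) repeat along M.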
    auto normalize_argument = normalize.MakeArgument(
        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
        static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
        static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
        static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
        static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
        static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
        {M, N},
        {StrideC, 1},
        {1, 0},
        {1, 0},
        {0, 1},
        {0, 1},
        {StrideC, 1},
        NormalizeFunctor{});

    if(!normalize.IsSupportedArgument(normalize_argument))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
                                 "Device5AryElementwise instance, exiting!");
    }

    // run kernel
    gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
    normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});

    bool pass = true;
    {
        // verification
        Tensor<LayerNormOutDataType> host_layerNorm_m_n(
            f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));

        host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
                                                                     a_m_k,
                                                                     b_k_n,
                                                                     bias_n,
                                                                     c1_m_n,
                                                                     gamma_n,
                                                                     beta_n,
                                                                     a_element_op,
                                                                     b_element_op,
                                                                     c_element_op,
                                                                     c1_element_op,
                                                                     M,
                                                                     N);

        layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
        pass &= ck::utils::check_err(layerNorm_m_n.mData,
                                     host_layerNorm_m_n.mData,
                                     "Error: Incorrect results layerNorm_m_n",
                                     1e-2,
                                     1e-2);
    }

    {
        // evaluate kernel perf
        bool time_kernel = true;

        float gemm_reduce_mean_reduce_square_mean_ave_time =
            gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
        float normalize_ave_time =
            normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});

        if(time_kernel)
            DumpGemmLayerNormPerf<ADataType,
                                  BDataType,
                                  CDataType,
                                  C0DataType,
                                  C1DataType,
                                  DDataType,
                                  GammaDataType,
                                  BetaDataType,
                                  LayerNormOutDataType>(
                gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
    }

    return pass ? 0 : 1;
}

@@ -2,7 +2,6 @@
#include <numeric>
#include <initializer_list>
#include <cstdlib>
-#include <stdlib.h>

#include "check_err.hpp"
#include "config.hpp"
@@ -54,8 +53,8 @@ using UnaryDivElementOp =
    ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
    ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
-using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
+using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;

using DxsGlobalMemOp =
    ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
@@ -70,7 +69,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
+< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
@@ -143,7 +142,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
    Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
-    auto averageOpInst = UnaryDivElementOp{M};
+    auto averageOpInst = UnaryDivElementOp{N};

    auto ref_gemm = ReferenceGemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();
@@ -162,7 +161,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,

    for(int n = 0; n < N; ++n)
    {
-        ReduceAccDataType c_val = ck::type_convert<float>(c_m_n(m, n));
+        ReduceAccDataType c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
        ReduceAccDataType square_c_val = 0;
        UnarySquareElementOp{}(square_c_val, c_val);

@@ -267,8 +266,8 @@ int main()
        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));

-    auto dxs_in_element_op = DxsInElementOp{};
-    auto dxs_out_element_op = DxsOutElementOp{M, M};
+    auto dxs_in_element_op = DxsInElementOps{};
+    auto dxs_out_element_op = DxsOutElementOps{N, N};

    // Prepare GEMM, reduce_mean, reduce_mean_square
    auto gemmReduce = DeviceGemmReduceInstance{};

@@ -3,7 +3,6 @@
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"

@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename DxsInElementwiseOperation,
-         typename DxsAccElementwiseOperation,
+         typename DxsReduceAccElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -44,7 +44,7 @@ __global__ void
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const DxsInElementwiseOperation dxs_in_element_op,
-           const DxsAccElementwiseOperation dxs_out_element_op,
+           const DxsReduceAccElementwiseOperation dxs_out_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -126,7 +126,7 @@ template <typename ALayout,
          typename CElementwiseOperation,
          typename DxsReduceOperation,
          typename DxsInElementwiseOperation,
-         typename DxsAccElementwiseOperation,
+         typename DxsReduceAccElementwiseOperation,
          typename DGlobalMemoryDataOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
@@ -162,12 +162,12 @@ template <typename ALayout,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
-                                                                      AElementwiseOperation,
-                                                                      BElementwiseOperation,
-                                                                      CElementwiseOperation,
-                                                                      DxsInElementwiseOperation,
-                                                                      DxsAccElementwiseOperation>
+struct DeviceBatchedGemmReduce_Xdl_CShuffle
+    : public DeviceGemmReduce<AElementwiseOperation,
+                              BElementwiseOperation,
+                              CElementwiseOperation,
+                              DxsInElementwiseOperation,
+                              DxsReduceAccElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;

@@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
        CElementwiseOperation,
        DxsReduceOperation,
        DxsInElementwiseOperation,
-       DxsAccElementwiseOperation,
+       DxsReduceAccElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        DGlobalMemoryDataOperation,
        AGridDesc_AK0_M_AK1,
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op,
                 DxsInElementwiseOperation dxs_in_element_op,
-                DxsAccElementwiseOperation dxs_out_element_op,
+                DxsReduceAccElementwiseOperation dxs_out_element_op,
                 index_t BatchCount)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
@@ -645,7 +645,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
        DxsInElementwiseOperation dxs_in_element_op_;
-       DxsAccElementwiseOperation dxs_out_element_op_;
+       DxsReduceAccElementwiseOperation dxs_out_element_op_;
    };

    // Invoker
@@ -703,7 +703,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
        BElementwiseOperation,
        CElementwiseOperation,
        DxsInElementwiseOperation,
-       DxsAccElementwiseOperation,
+       DxsReduceAccElementwiseOperation,
        DeviceOp::AGridDesc_AK0_M_AK1,
        DeviceOp::BGridDesc_BK0_N_BK1,
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -746,7 +746,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
        BElementwiseOperation,
        CElementwiseOperation,
        DxsInElementwiseOperation,
-       DxsAccElementwiseOperation,
+       DxsReduceAccElementwiseOperation,
        DeviceOp::AGridDesc_AK0_M_AK1,
        DeviceOp::BGridDesc_BK0_N_BK1,
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -832,7 +832,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
        BElementwiseOperation b_element_op,
        CElementwiseOperation c_element_op,
        DxsInElementwiseOperation dxs_in_element_op,
-       DxsAccElementwiseOperation dxs_out_element_op,
+       DxsReduceAccElementwiseOperation dxs_out_element_op,
        index_t BatchCount)
    {
        return Argument{p_a,
@@ -856,27 +856,29 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      DPtrsGlobal p_dxs,
-                                                      index_t MRaw,
-                                                      index_t NRaw,
-                                                      index_t KRaw,
-                                                      index_t StrideA,
-                                                      index_t StrideB,
-                                                      index_t StrideC,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      CElementwiseOperation c_element_op,
-                                                      DxsInElementwiseOperation dxs_in_element_op,
-                                                      DxsAccElementwiseOperation dxs_out_element_op,
-                                                      index_t BatchCount) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        void* p_dxs,
+                        index_t MRaw,
+                        index_t NRaw,
+                        index_t KRaw,
+                        index_t StrideA,
+                        index_t StrideB,
+                        index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        DxsInElementwiseOperation dxs_in_element_op,
+                        DxsReduceAccElementwiseOperation dxs_out_element_op,
+                        index_t BatchCount) override
    {
+        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
-                                         p_dxs,
+                                         dxs_tuple,
                                          MRaw,
                                          NRaw,
                                          KRaw,

@@ -0,0 +1,813 @@
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_gemm_reduce.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
// non-c-shuffle version currently has compiler issues with register spill, which in turn causes
// validation failures.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename C0DataType,
          typename C1DataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename ReduceAccDataType,
          typename DPtrsGlobal,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename C1ElementwiseOperation,
          typename DxsReduceOperation,
          typename DxsInElementwiseOperation,
          typename DxsReduceAccElementwiseOperation,
          typename DGlobalMemoryDataOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
          index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasAddReduce_Xdl_CShuffle
    : public DeviceGemmBiasAddReduce<AElementwiseOperation,
                                     BElementwiseOperation,
                                     CElementwiseOperation,
                                     C1ElementwiseOperation,
                                     DxsInElementwiseOperation,
                                     DxsReduceAccElementwiseOperation>
{
    using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

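    // Builds the (AK0, M, AK1) view of A: M and K are rounded up to multiples of
    // MPerBlock / KPerBlock, right-padded as dictated by GemmSpec, and K is then
    // unmerged into (AK0, AK1) with AK0 = K / AK1.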
    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(MRaw, MPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_m_k,
                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                       make_pass_through_transform(M)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                       make_right_pad_transform(MRaw, MPad)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_m_k,
                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                       make_pass_through_transform(MRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                       make_pass_through_transform(MRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(NRaw, NPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_n_k,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(N)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_right_pad_transform(NRaw, NPad)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_n_k,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(NRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                       make_pass_through_transform(NRaw)),
                                            make_tuple(Sequence<1>{}, Sequence<0>{}),
                                            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }

    // assume D is packed tensor
    static auto MakeDGridDescriptor_M(index_t MRaw)
    {
        const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(d_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // not pad M
            return d_grid_desc_mraw;
        }
    }

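    // Descriptor types are recovered from dummy calls: the transform chain depends
    // only on compile-time parameters, so decltype over (1, 1, 1) arguments (and
    // stride 0 for the broadcast C0 bias row) yields the correct types.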
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
|
||||
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
ReduceAccDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
C1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
C1GridDesc_M_N,
|
||||
DGridDesc_M,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopSched>;

// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const C0DataType* p_c0_grid,
const C1DataType* p_c1_grid,
DPtrsGlobal p_ds_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
p_ds_grid_{p_ds_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
c1_element_op_{c1_element_op},
dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);

c0_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c0_grid_desc_m_n_);

c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);

d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
}
}

// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
DPtrsGlobal p_ds_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
};

// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;

float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}

const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
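// Note: HasMainKBlockLoop is a compile-time template parameter, so the kernel
// exists in two instantiations; this branch only selects which one to launch.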
const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;

elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
C0DataType,
C1DataType,
DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;

elapsed_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_,
arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_);
}

return elapsed_time;
}

// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}

static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}

// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}

static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
const C0DataType* p_c0,
const C1DataType* p_c1,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_c0,
p_c1,
p_dxs,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op};
}

static auto MakeInvoker() { return Invoker{}; }

// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
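// The opaque p_dxs pointer is reinterpreted as the concrete tuple of D output
// pointers; callers are expected to pass the address of a DPtrsGlobal object.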
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0),
static_cast<const C1DataType*>(p_c1),
dxs_tuple,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
}

// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}

// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();

// clang-format off
str << "DeviceGemmReduce_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1
<< ">";
// clang-format on

return str.str();
}
};

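// A minimal usage sketch (illustrative only: op is assumed to be a concrete
// instantiation of this device op, and the buffer pointers, problem sizes,
// and element-op instances are assumed to be prepared by the caller):
//
//   auto argument = op.MakeArgument(p_a, p_b, p_c, p_c0, p_c1, p_dxs,
//                                   M, N, K, StrideA, StrideB, StrideC,
//                                   StrideC1, a_op, b_op, c_op, c1_op,
//                                   dxs_in_op, dxs_out_op);
//   auto invoker  = op.MakeInvoker();
//   if(op.IsSupportedArgument(argument))
//       invoker.Run(argument, StreamConfig{});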
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -6,19 +6,18 @@ namespace ck {
namespace tensor_operation {
namespace device {

template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;

virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;
DxsReduceAccElementwiseOperation>>;

template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;

virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation

@@ -32,7 +32,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
@@ -68,12 +68,11 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>
DxsReduceAccElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;

@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
};

// Invoker
@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
@@ -691,27 +690,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeInvoker() { return Invoker{}; }

// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
p_dxs,
dxs_tuple,
MRaw,
NRaw,
KRaw,

@@ -144,6 +144,27 @@ struct AddHardswishAdd
}
};

struct Relu
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}

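// bhalf_t is stored as an unsigned 16-bit bit pattern, so a direct x > 0
// comparison on it would be meaningless; the specialization below converts to
// float, applies relu there, and converts back.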
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{
float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32);
}
};

struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}

@@ -0,0 +1,988 @@
#pragma once
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "reduction_functions_threadwise.hpp"

namespace ck {

template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
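// The body below is compiled only for targets with XDL/MFMA support
// (gfx908/gfx90a); elsewhere the parameters are consumed via ignore to avoid
// unused-parameter warnings.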
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_c0_grid,
p_c1_grid,
p_ds_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = p_ds_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = c1_element_op;
ignore = dxs_in_element_op;
ignore = dxs_out_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
}

template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc,
typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename DGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched>
struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};

// K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};

using ThisThreadBlock = ThisThreadBlock<BlockSize>;

using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;

__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}

__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}

__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}

__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);

constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

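// The A/B input tiles and the C-shuffle buffer occupy the same LDS
// allocation in different phases of the kernel, so only the larger of the
// two footprints needs to be reserved: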
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatCShuffle));
}

// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
// "wrong! K1 need to be known at compile-time");

static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");

const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);

if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
return false;

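// Problem sizes must tile exactly; any required padding is assumed to have
// been applied already when the grid descriptors were built.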
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
return false;

// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;

if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}

if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}

// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}

__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;

return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}

template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);

const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;

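// e.g. with MPerBlock = 256 and NPerBlock = 128, an (M, N) = (1024, 512)
// descriptor is unmerged into a (MBlock, MPerBlock, NBlock, NPerBlock) =
// (4, 256, 4, 128) descriptor.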
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

return c_grid_desc_mblock_mperblock_nblock_nperblock;
}

__host__ __device__ static constexpr auto
MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;

const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));

return d_grid_desc_mblock_mperblock;
}

// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}

using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;

using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;

using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;

using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}

// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});

// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
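// KPack is the K-length fed to each mfma issue: at least the LDS vector
// widths (AK1/BK1) and at least the selected instruction's k_per_blk.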
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();

auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());

constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();

const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);

gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);

// shuffle C + reduction + write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");

constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));

const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));

const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));

const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));

// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};

// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};

// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

// TODO: this should be implemented as a blockwise reduction
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
make_freeze_transform(I0),
make_pass_through_transform(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{}));

static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
BlockSize,
"wrong!");

static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
0 &&
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
0,
"wrong!");

constexpr index_t mreduce_per_thread =
(CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);

constexpr index_t nreduce_per_thread =
(CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);

constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
Sequence<mreduce_per_thread, nreduce_per_thread>{};

// VGPR c_reduce_thread_desc_mperblock_nperblock
constexpr auto c_reduce_thread_desc_mperblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(Number<mreduce_per_thread>{}, Number<nreduce_per_thread>{}));

// VGPR d_reduce_thread_desc_mperblock
constexpr auto d_reduce_thread_desc_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(Number<mreduce_per_thread>{}));

// VGPR d_reduce_thread_desc_mblock_mperblock
constexpr auto d_reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

// reduce: threadwise copy from LDS to VGPR
constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});

const auto c_reduce_thread_cluster_idx =
c_reduce_thread_cluster_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));

const auto c_reduce_thread_data_idx_begin =
c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;

auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatReduceAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
Sequence<0, 1>,
1,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};

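// One threadwise VGPR-to-global writer per reduction output; each one picks
// up its own output element-op and in-memory data operation (e.g. AtomicAdd
// for partial results) from the corresponding tuples.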
auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_d_grid = p_ds_grid[I];
auto d_out_element_op = dxs_out_element_op[I];

return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
remove_pointer_t<decltype(p_d_grid)>,
decltype(d_reduce_thread_desc_mblock_mperblock),
decltype(d_grid_desc_mblock_mperblock),
decltype(d_out_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
DGlobalMemoryDataOperation::At(I),
1,
false>{d_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
d_out_element_op};
},
Number<p_ds_grid.Size()>{});

// c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock;

auto c01_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC0,
FloatReduceAcc,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));

auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC1,
FloatReduceAcc,
decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));

constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));

auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatC,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};

constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

static_for<0, num_access, 1>{}([&](auto access_id) {
// each thread writes its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);

// make sure it's safe to write to LDS
block_sync_lds();
{
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf,
c_reduce_thread_desc_mperblock_nperblock,
make_tuple(I0, I0),
c_reduce_thread_buf);

c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_buf,
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);

// c = activation(c + bias)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
FloatReduceAcc out;
c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});

c1_thread_copy_global_to_vgpr.Run(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_buf,
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);

// c = c + c1_functor(c1)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
c1_element_op(c01_thread_buf(i), c01_thread_buf(i));
c_reduce_thread_buf(i) += c01_thread_buf(i);
});

c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);

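// For each reduction target Dx: apply the per-element input op in place on
// the (m, n) thread tile, reduce it down to m, then write the partial result
// out to global memory.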
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
auto& p_d_grid = p_ds_grid[In];
|
||||
|
||||
auto d_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto d_thread_buf =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
|
||||
d_reduce_thread_desc_mperblock.GetElementSpaceSize());
|
||||
|
||||
auto& d_in_element_op = dxs_in_element_op[In];
|
||||
|
||||
auto& d_reduce_thread_copy_vgpr_to_global =
|
||||
dxs_reduce_thread_copy_vgpr_to_global(In);
|
||||
|
||||
using DReduceOperation = remove_cvref_t<decltype(DxsReduceOperation{}[In])>;
|
||||
using ThreadwiseReduce =
|
||||
ThreadwiseReduction<FloatReduceAcc,
|
||||
decltype(c_reduce_thread_desc_mperblock_nperblock),
|
||||
decltype(d_reduce_thread_desc_mperblock),
|
||||
DReduceOperation,
|
||||
false>;
|
||||
|
||||
// Global write Gemm shuffle + reduction
|
||||
const auto d_zeroVal = DReduceOperation::GetIdentityValue();
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { d_thread_buf(I) = d_zeroVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
d_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf);
|
||||
|
||||
// copy from VGPR to Global
|
||||
d_reduce_thread_copy_vgpr_to_global.Run(
|
||||
d_reduce_thread_desc_mblock_mperblock,
|
||||
make_tuple(I0, I0),
|
||||
d_thread_buf,
|
||||
d_grid_desc_mblock_mperblock,
|
||||
d_grid_buf);
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
d_grid_desc_mblock_mperblock,
|
||||
make_tuple(c_global_step[I0], c_global_step[I1]));
|
||||
}
|
||||
});
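        // Note: the instances below set DGlobalMemoryDataOperation to AtomicAdd, so
        // thread blocks covering different N-tiles of the same rows accumulate their
        // partial reductions directly in global memory.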
|
||||
}
|
||||
|
||||
if constexpr(access_id < num_access - 1)
|
||||
{
|
||||
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
|
||||
|
||||
// move on C
|
||||
c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
|
||||
// move on C0
|
||||
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
|
||||
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
|
||||
// move on C1
|
||||
c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
|
||||
}
|
||||
});
|
||||
} // Reduction
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
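For orientation, the Dxs element-wise pairs wired into these reductions (Identity and Square on the input side, Div on the output side) make d0 a per-row mean and d1 a per-row mean of squares, the statistics the follow-up layernorm consumes. A minimal host-side sketch of the same math, assuming a row-major M x N C tensor and that the Div functor divides by the reduce length N (function and variable names here are illustrative, not from this commit):

#include <vector>

// Reference sketch (illustrative): per-row mean and mean of squares of C.
void reference_row_reductions(const std::vector<float>& c, int M, int N,
                              std::vector<float>& d0, std::vector<float>& d1)
{
    for(int m = 0; m < M; ++m)
    {
        float sum = 0.f, sum_sq = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float v = c[m * N + n];
            sum += v;        // in-op Identity, reduced with Add
            sum_sq += v * v; // in-op Square, reduced with Add
        }
        d0[m] = sum / N;    // out-op Div
        d1[m] = sum_sq / N; // out-op Div
    }
}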
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename DxsInElementwiseOperation,
          typename DxsAccElementwiseOperation,
          typename DxsReduceAccElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -41,7 +41,7 @@ __global__ void
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const DxsInElementwiseOperation dxs_in_element_op,
            const DxsAccElementwiseOperation dxs_out_element_op,
            const DxsReduceAccElementwiseOperation dxs_out_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -96,7 +96,7 @@ template <typename FloatAB,
          typename CElementwiseOperation,
          typename DxsReduceOperation,
          typename DxsInElementwiseOperation,
          typename DxsAccElementwiseOperation,
          typename DxsReduceAccElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename DGlobalMemoryDataOperation,
          typename AGridDesc_AK0_M_AK1,
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                   const BElementwiseOperation& b_element_op,
                   const CElementwiseOperation& c_element_op,
                   const DxsInElementwiseOperation& dxs_in_element_op,
                   const DxsAccElementwiseOperation& dxs_out_element_op,
                   const DxsReduceAccElementwiseOperation& dxs_out_element_op,
                   const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&

@@ -20,7 +20,7 @@ include_directories(BEFORE

function(add_instance_library INSTANCE_NAME)
    message("adding instance ${INSTANCE_NAME}")
    add_library(${INSTANCE_NAME} OBJECT ${ARGN})
    target_compile_features(${INSTANCE_NAME} PUBLIC)
    set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME)
@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
@@ -44,12 +45,12 @@ add_subdirectory(grouped_gemm)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(batched_gemm_reduce)

add_library(device_operations STATIC
    $<TARGET_OBJECTS:device_conv1d_fwd_instance>
    $<TARGET_OBJECTS:device_batched_gemm_instance>
    $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
    $<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
    $<TARGET_OBJECTS:device_gemm_instance>
@@ -67,14 +68,14 @@ add_library(device_operations STATIC
add_library(composablekernels::device_operations ALIAS device_operations)

set(DEV_OPS_INC_DIRS
    ${PROJECT_SOURCE_DIR}/include/ck/
    ${PROJECT_SOURCE_DIR}/library/include/ck/
    ${PROJECT_SOURCE_DIR}/external/include/
)
target_compile_features(device_operations PUBLIC)
set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(device_operations PUBLIC
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/utility>
    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_description>
@@ -108,8 +109,8 @@ install(TARGETS device_operations
    INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
install(EXPORT device_operationsTargets
    FILE composable_kerneldevice_operationsTargets.cmake
    NAMESPACE composable_kernel::
    DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)

@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances,

@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances,

@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances,

@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>;

void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances,

@@ -0,0 +1,10 @@
set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
    device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
    device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
    device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
    device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)

add_instance_library(device_gemm_bias_add_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE})
install(TARGETS device_gemm_bias_add_reduce_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_bias_add_reduce_instance)
@@ -0,0 +1,81 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

using Div      = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps  = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[m, n] = a[k, m] * b[k, n]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
    std::tuple<
        // clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
        >;

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{});
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
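As a hedged usage sketch of the registration function above (the alias parameters mirror DeviceGemmBiasAddReducePtr as declared elsewhere in this commit; the vector name is illustrative and the file's own type aliases are assumed in scope):

// Collect the km_kn instances registered above; each entry can then build an
// argument, be checked for support, and be timed by the profiler.
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                       PassThrough,
                                       PassThrough,
                                       PassThrough,
                                       DInElementOps,
                                       DOutElementOps>> instances;

add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
    instances);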
@@ -0,0 +1,81 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

using Div      = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps  = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[m, n] = a[k, m] * b[n, k]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances =
    std::tuple<
        // clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
        >;

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{});
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,81 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

using Div      = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps  = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[m, n] = a[m, k] * b[k, n]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances =
    std::tuple<
        // clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
        >;

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{});
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,78 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;

using Div      = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps  = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
                                                          ck::InMemoryDataOperationEnum::AtomicAdd>;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// c[m, n] = a[m, k] * b[n, k]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances =
    std::tuple<
        // clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>
// clang-format on
        >;

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           PassThrough,
                                           DInElementOps,
                                           DOutElementOps>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{});
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
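All four instance files follow the same template; only the layout tags and the block-transfer tuning differ. A hedged sketch of how a caller might pick the matching registration by layout (the add_* functions and the DeviceGemmBiasAddReduceNoOpPtr alias come from this commit; the dispatch helper itself is illustrative):

#include <type_traits>
#include <vector>

// Illustrative dispatch over the four layout combinations added in this commit.
template <typename ALayout, typename BLayout>
void add_instances_for_layout(std::vector<DeviceGemmBiasAddReduceNoOpPtr>& instances)
{
    if constexpr(std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>)
        add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(instances);
    else if constexpr(std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Col>)
        add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(instances);
    else if constexpr(std::is_same_v<ALayout, Col> && std::is_same_v<BLayout, Row>)
        add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(instances);
    else
        add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(instances);
}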
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = s
>;

void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{});

@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = s
>;

void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{});

@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = s
>;

void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{});

@@ -59,12 +59,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = s
>;

void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
                                    PassThrough,
                                    PassThrough,
                                    PassThrough,
                                    DInElementOps,
                                    DOutElementOps>>& instances)
    std::vector<
        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
        instances)
{
    add_device_operation_instances(
        instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{});

@@ -29,6 +29,7 @@ set(PROFILER_SOURCE
    src/profile_gemm_bias_relu.cpp
    src/profile_gemm_bias_relu_add.cpp
    src/profile_gemm_reduce.cpp
    src/profile_gemm_bias_add_reduce.cpp
    src/profile_batched_gemm.cpp
    src/profile_conv_fwd_bias_relu.cpp
    src/profile_conv_fwd_bias_relu_add.cpp
@@ -46,6 +47,7 @@ add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE conv_util)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)

@@ -26,7 +26,6 @@ using DInElementOps = ck::Tuple<Identity, Square>;
|
||||
using DOutElementOps = ck::Tuple<Identity, Identity>;
|
||||
|
||||
using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
|
||||
DPtrsGlobal,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -260,7 +259,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
|
||||
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
&dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
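
With the tuple type gone from the base class, the reduction-pointer tuple can no longer cross the type-erased interface by value, which is why callers now pass its address (`&dxs_global`) and the concrete device operation casts it back internally. A hedged sketch of the assumed signature shape (names hypothetical):

// base interface now takes an opaque pointer to the caller-owned tuple
// virtual std::unique_ptr<BaseArgument>
// MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c,
//                     void* p_dxs_tuple /* &dxs_global */, ...);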

388 profiler/include/profile_gemm_bias_add_reduce_impl.hpp Normal file
@@ -0,0 +1,388 @@
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_gemm_reduce.hpp"
#include "reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
DInElementOps,
DOutElementOps>;

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

namespace ck {
namespace profiler {

template <typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename DDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void profile_gemm_bias_add_reduce_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int StrideC1)
{
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};

auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};

Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));

Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC1, CLayout{}));
Tensor<DDataType> d0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));

Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;

std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
std::srand(0);
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_2<C1DataType>{-5, 5}, num_thread);
break;
default:
std::srand(0);
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-0.5, 0.5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-0.5, 0.5}, num_thread);
}

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
using C1ElementOp = PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
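// Intent of this plumbing (editorial annotation, not original source): the two
// reductions accumulate identity(c) and square(c) over the N elements of each
// row, and the Div output ops (constructed with N below) divide the sums by N.
// The kernel thus produces E[c] and E[c^2] per row -- exactly the statistics a
// following layernorm needs, since Var[c] = E[c^2] - E[c]^2.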

const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto c1_element_op = C1ElementOp{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};

auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};

if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
DDataType,
AElementOp,
BElementOp,
CElementOp>;

using ReduceAccDataType = DDataType;

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{});

ref_invoker.Run(ref_argument);

for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n)
{
ReduceAccDataType acc = static_cast<ReduceAccDataType>(c_m_n_host_result(m, n)) +
static_cast<ReduceAccDataType>(bias_n(n));

ReduceAccDataType c1 = static_cast<ReduceAccDataType>(c1_m_n(m, n));
c_element_op(acc, acc);
c1_element_op(c1, c1);
acc += c1;
c_m_n_host_result(m, n) = static_cast<CDataType>(acc);
}

for(int m = 0; m < M; ++m)
{
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();

for(int n = 0; n < N; ++n)
{
ReduceAccDataType c_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;

dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
}

dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
}
}

DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());

auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
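// dxs_global packs the two reduction-output device pointers into a tuple; its
// address (not the tuple itself) is handed to MakeArgumentPointer below,
// because the type-erased base interface no longer knows the tuple's type.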

a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
c1_device_buf.ToDevice(c1_m_n.mData.data());

// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr>
gemm_ptrs;

if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
}
}

if(gemm_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}

std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;

// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
&dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);

auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// init D0, D1 to 0
d0_device_buf.SetZero();
d1_device_buf.SetZero();

float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

std::string gemm_name = gemm_ptr->GetTypeString();

std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;

std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

float gb_per_sec = num_byte / 1.E6 / ave_time;

std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm_name << std::endl;

if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}

if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());

ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);

if(do_log)
{
LogRangeAsType<float>(std::cout << "a: ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host: ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_host: ", d0_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0_device: ", d0_m_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_host: ", d1_m_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d1_device: ", d1_m_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
}
}

std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
}

} // namespace profiler
} // namespace ck
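
For orientation, a hypothetical call into this implementation (the template arguments mirror the fp16 row-major dispatch in profile_gemm_bias_add_reduce.cpp below; the sizes are illustrative only):

ck::profiler::profile_gemm_bias_add_reduce_impl<ck::half_t, ck::half_t, ck::half_t,
                                                ck::half_t, ck::half_t, float,
                                                ck::tensor_layout::gemm::RowMajor,
                                                ck::tensor_layout::gemm::RowMajor,
                                                ck::tensor_layout::gemm::RowMajor>(
    1 /*do_verification*/, 2 /*init_method*/, 0 /*do_log*/, 1 /*time_kernel*/,
    1024 /*M*/, 1024 /*N*/, 1024 /*K*/,
    1024 /*StrideA*/, 1024 /*StrideB*/, 1024 /*StrideC*/, 1024 /*StrideC1*/);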

@@ -1,4 +1,5 @@
#pragma once
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"

@@ -26,7 +27,6 @@ using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;

using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
DPtrsGlobal,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,

@@ -143,7 +143,7 @@ bool profile_gemm_reduce_impl(int do_verification,
const auto d1_reduce_op = D1ReduceOp{};

auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{M, M};
auto dxs_out_element_op = DxsOutElementOps{N, N};
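// annotation: the reduction runs over the N elements of each row, so the mean
// divisor must be N; the old {M, M} constructor arguments used the wrong extent.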

if(do_verification)
{

@@ -155,6 +155,8 @@ bool profile_gemm_reduce_impl(int do_verification,
BElementOp,
CElementOp>;

using ReduceAccDataType = DDataType;

auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

@@ -165,14 +167,15 @@ bool profile_gemm_reduce_impl(int do_verification,

for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();

for(int n = 0; n < N; ++n)
{
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d0_val = 0;
float d1_val = 0;
ReduceAccDataType c_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;

dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);

@@ -257,7 +260,7 @@ bool profile_gemm_reduce_impl(int do_verification,
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
dxs_global,
&dxs_global,
M,
N,
K,

@@ -309,13 +312,9 @@ bool profile_gemm_reduce_impl(int do_verification,
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data());

float c_error = check_error(c_m_n_host_result, c_m_n_device_result);
float d0_error = check_error(d0_m_host_result, d0_m_device_result);
float d1_error = check_error(d1_m_host_result, d1_m_device_result);

pass = pass && (c_error < 1E-6);
pass = pass && (d0_error < 1E-6);
pass = pass && (d1_error < 1E-6);
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData);

if(do_log)
{

159 profiler/src/profile_gemm_bias_add_reduce.cpp Normal file
@@ -0,0 +1,159 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_gemm_bias_add_reduce_impl.hpp"

int profile_gemm_bias_add_reduce(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};

enum struct GemmReduceDataType
{
F32_F32_F32_F32_F32_F32_F32, // 0
F16_F16_F16_F16_F16_F32_F32, // 1
};

if(argc != 15)
{
printf("arg1: tensor operation (gemm_bias_add_reduce: GEMM+bias+add+Reduce)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0: no; 1: yes)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n");
exit(1);
}

const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);

const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);

const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int StrideC1 = std::stoi(argv[14]);

if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_bias_add_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
(StrideC1 < 0) ? N : StrideC1);
}
else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_bias_add_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
(StrideC1 < 0) ? N : StrideC1);
}
else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_bias_add_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
(StrideC1 < 0) ? N : StrideC1);
}
else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_bias_add_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
(StrideC1 < 0) ? N : StrideC1);
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}

return 0;
}
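
A sample ckProfiler invocation for this operation (argument order as printed by the usage text above; the sizes are illustrative only):

./ckProfiler gemm_bias_add_reduce 1 0 1 2 0 1 1024 1024 1024 1024 1024 1024 1024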

@@ -11,6 +11,7 @@ int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_gemm_bias_add_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd(int, char*[]);

@@ -44,6 +45,10 @@ int main(int argc, char* argv[])
{
return profile_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_add_reduce") == 0)
{
return profile_gemm_bias_add_reduce(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);