mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
@@ -25,115 +25,76 @@ struct PassThrough
|
||||
|
||||
struct Relu
|
||||
{
|
||||
float alpha = 0.1;
|
||||
|
||||
// ReLU
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T operator()(T v) const
|
||||
{
|
||||
T tmp = alpha * v;
|
||||
return tmp > 0 ? tmp : 0;
|
||||
return v > 0 ? v : 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AOp = PassThrough;
|
||||
using BOp = PassThrough;
|
||||
using COp = Relu;
|
||||
|
||||
// Compilation parameters for NT problem
|
||||
// clang-format off
|
||||
using DeviceGemmInstance =
|
||||
//#########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//#########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
ck::tensor_operation::device::DeviceGemmXdl< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
|
||||
// clang-format on
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmInstance;
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmInstance<ck::half_t,
|
||||
ck::half_t,
|
||||
ck::half_t,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
static void host_verify(const Tensor<AType>& a_m_k,
|
||||
const Tensor<BType>& b_k_n,
|
||||
Tensor<CType>& c_m_n,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CElementwiseOperation& c_element_op)
|
||||
{
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a_m_k.mDesc.GetLengths()[1];
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
double v = 0;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
|
||||
static_cast<const double>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
|
||||
using AOp = AElementwiseOperation;
|
||||
using BOp = BElementwiseOperation;
|
||||
using COp = CElementwiseOperation;
|
||||
c_m_n(m, n) = c_element_op(v);
|
||||
};
|
||||
|
||||
// Compilation parameters for NT problem
|
||||
// clang-format off
|
||||
using type =
|
||||
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
ck::tensor_operation::device::DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmInstance<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
ck::tensor_layout::gemm::ColumnMajor,
|
||||
ck::tensor_layout::gemm::RowMajor,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using AOp = AElementwiseOperation;
|
||||
using BOp = BElementwiseOperation;
|
||||
using COp = CElementwiseOperation;
|
||||
|
||||
// Compilation parameters for NT problem
|
||||
// clang-format off
|
||||
using type =
|
||||
//########################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
|
||||
//########################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
|
||||
//########################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
ck::tensor_operation::device::DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, AOp, BOp, COp, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>;
|
||||
// clang-format on
|
||||
};
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn,
|
||||
c_m_n.mDesc.GetLengths()[0],
|
||||
c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const bool do_verification = std::stoi(argv[1]);
|
||||
const int init_method = std::stoi(argv[2]);
|
||||
const int nrepeat = std::stoi(argv[3]);
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -144,15 +105,34 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
// matrix data type
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
if(argc == 4)
|
||||
{
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
|
||||
// matrix layout
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -198,16 +178,7 @@ int main(int argc, char* argv[])
|
||||
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
// do GEMM
|
||||
auto gemm = typename DeviceGemmInstance<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Relu>::type{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
@@ -218,9 +189,9 @@ int main(int argc, char* argv[])
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
Relu{});
|
||||
AOp{},
|
||||
BOp{},
|
||||
COp{});
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -233,7 +204,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -246,7 +217,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
|
||||
host_verify(a_m_k, b_k_n, c_m_n_host_result, AOp{}, BOp{}, COp{});
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
}
|
||||
|
||||
@@ -20,10 +20,42 @@
|
||||
// 0 in the "n" dimension
|
||||
// assume C1 and C have same layout C
|
||||
|
||||
struct BiasReluAdd
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
{
|
||||
float b = v0 + v1;
|
||||
float c = b > 0 ? b : 0;
|
||||
float d = c + v2;
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
{
|
||||
#if 0
|
||||
float a = v1 + v0;
|
||||
float b = max(a, float(0));
|
||||
float c = b + v2;
|
||||
|
||||
return c;
|
||||
#else
|
||||
float a = v1 + v2;
|
||||
float b = v2;
|
||||
|
||||
float c = (v0 > -v1) ? a + v0 : v2;
|
||||
|
||||
return c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// v0 is from A * B
|
||||
// v1 is from C0
|
||||
// v2 is from C1
|
||||
struct BiasReluAdd
|
||||
struct BiasLeakyReluAdd
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
@@ -51,7 +83,7 @@ struct BiasReluAdd
|
||||
}
|
||||
};
|
||||
|
||||
struct BiasRelu
|
||||
struct BiasLeakyRelu
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2) const
|
||||
@@ -99,7 +131,7 @@ struct BiasAdd
|
||||
}
|
||||
#elif 0
|
||||
float alpha = 0.1;
|
||||
float beta = 0.2;
|
||||
float beta = 0.2;
|
||||
float gamma = 0.3;
|
||||
|
||||
// wrong result
|
||||
|
||||
@@ -23,7 +23,7 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
struct BiasReluAdd
|
||||
struct BiasLeakyReluAdd
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
@@ -97,7 +97,39 @@ struct BiasReluAdd
|
||||
}
|
||||
};
|
||||
|
||||
struct BiasRelu
|
||||
struct BiasReluAdd
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
{
|
||||
float b = v0 + v1;
|
||||
float c = b > 0 ? b : 0;
|
||||
float d = c + v2;
|
||||
|
||||
return d;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
__device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
|
||||
{
|
||||
#if 0
|
||||
float a = v1 + v0;
|
||||
float b = max(a, float(0));
|
||||
float c = b + v2;
|
||||
|
||||
return c;
|
||||
#else
|
||||
float a = v1 + v2;
|
||||
float b = v2;
|
||||
|
||||
float c = (v0 > -v1) ? a + v0 : v2;
|
||||
|
||||
return c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct BiasLeakyRelu
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
__host__ constexpr float operator()(float v0, T1 v1, T2) const
|
||||
@@ -377,6 +409,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
|
||||
sizeof(WeiDataType) * (K * C * Y * X) +
|
||||
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) +
|
||||
sizeof(OutDataType) * (N * K * Ho * Wo);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
Reference in New Issue
Block a user