mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Add contraction profiler and tests (#701)
* Add contraction profiler and tests
* Build and style fixes
* Allow to use any elementwise operator for ref_contraction
* Introduce profile_contraction_scale and profile_contraction_bilinear
* Make ref_contraction generic and extend interface tests
* Stylistic minor fixes
* Extend test_contraction_interface
[ROCm/composable_kernel commit: 642d5e9155]
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
// Shorthand alias: S<Is...> names ck::Sequence<Is...>, a pack of ck::index_t values.
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
|
||||
|
||||
// DeviceOpInstance is an alias for the DeviceOpInstanceKKNN instance type.
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
// Shorthand alias: S<Is...> names ck::Sequence<Is...>, a pack of ck::index_t values.
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
|
||||
|
||||
// DeviceOpInstance is an alias for the DeviceOpInstanceKKNN instance type.
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
// Shorthand alias: S<Is...> names ck::Sequence<Is...>, a pack of ck::index_t values.
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
|
||||
|
||||
// DeviceOpInstance is an alias for the DeviceOpInstanceKKN instance type.
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
// Shorthand alias: S<Is...> names ck::Sequence<Is...>, a pack of ck::index_t values.
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
|
||||
|
||||
// DeviceOpInstance is an alias for the DeviceOpInstanceKKN instance type.
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<CDataType>& c_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
c_ms_ns_{c_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<CDataType>& c_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(ck::index_t k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(ck::index_t k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.c_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.c_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.c_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.c_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<CDataType>& c_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, c_ms_ns, a_element_op, b_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -46,3 +46,33 @@ out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
....
|
||||
Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
|
||||
```
|
||||
|
||||
## Profile contraction kernels
|
||||
```bash
|
||||
#arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear)
|
||||
#arg2: data type (0: fp32; 1: fp64)
|
||||
#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
|
||||
# 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
|
||||
# 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1];
|
||||
# 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1])
|
||||
#arg4: verification (0: no; 1: yes)
|
||||
#arg5: initialization (0: no init; 1: integer value; 2: decimal value)
|
||||
#arg6: print tensor value (0: no; 1: yes)
|
||||
#arg7: time kernel (0: no, 1: yes)
|
||||
#arg8 and arg9: alpha and beta
|
||||
#arg10 to 15: M0, M1, N0, N1, K0, K1
|
||||
#arg16 to 31: Strides for A, B, D and E (skip for default)
|
||||
|
||||
################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1
|
||||
./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
|
||||
```
|
||||
|
||||
Result (MI100)
|
||||
```bash
|
||||
a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
|
||||
b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384}
|
||||
d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
|
||||
e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
|
||||
....
|
||||
Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s
|
||||
```
|
||||
|
||||
345
profiler/include/profiler/profile_contraction_impl.hpp
Normal file
345
profiler/include/profiler/profile_contraction_impl.hpp
Normal file
@@ -0,0 +1,345 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CDELayout,
|
||||
typename DataType,
|
||||
typename DTupleDataType,
|
||||
typename CDElementOp>
|
||||
int profile_contraction_impl(ck::index_t do_verification,
|
||||
ck::index_t init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
CDElementOp cde_element_op,
|
||||
const std::vector<ck::index_t>& M,
|
||||
const std::vector<ck::index_t>& N,
|
||||
const std::vector<ck::index_t>& K,
|
||||
const std::vector<ck::index_t>& StridesA,
|
||||
const std::vector<ck::index_t>& StridesB,
|
||||
const std::vector<ck::index_t>& StridesE,
|
||||
const std::vector<ck::index_t>& StridesD)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor = [](const std::vector<ck::index_t>& dims01,
|
||||
const std::vector<ck::index_t>& dims23,
|
||||
const std::vector<ck::index_t>& strides) {
|
||||
std::vector<std::size_t> dims_szt(dims01.begin(), dims01.end());
|
||||
dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end());
|
||||
std::vector<std::size_t> strides_szt(strides.begin(), strides.end());
|
||||
|
||||
return HostTensorDescriptor(dims_szt, strides);
|
||||
};
|
||||
|
||||
Tensor<DataType> a_m_k(f_host_tensor_descriptor(M, K, StridesA));
|
||||
Tensor<DataType> b_k_n(f_host_tensor_descriptor(K, N, StridesB));
|
||||
Tensor<DataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
|
||||
Tensor<DataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE));
|
||||
Tensor<DataType> d_m_n(f_host_tensor_descriptor(M, N, StridesD));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
e_device_buf.SetZero();
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
|
||||
const std::vector<index_t> a_ms_ks_lengths = {M[0], M[1], K[0], K[1]};
|
||||
const std::vector<index_t> b_ns_ks_lengths = {N[0], N[1], K[0], K[1]};
|
||||
const std::vector<index_t> e_ms_ns_lengths = {M[0], M[1], N[0], N[1]};
|
||||
const std::vector<index_t> d_m_n_lengths = {M[0], M[1], N[0], N[1]};
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDim = 2;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
|
||||
NumDim,
|
||||
NumDim,
|
||||
DataType,
|
||||
DataType,
|
||||
DTupleDataType,
|
||||
DataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference op
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDim,
|
||||
NumDim,
|
||||
NumDim,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_op = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
Tensor<DataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE));
|
||||
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(size_t m0 = 0; m0 < e_m_n_host_result.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_m_n_host_result.mDesc.GetLengths()[1]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1)
|
||||
{
|
||||
if constexpr(is_same<CDElementOp, Bilinear>::value)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
|
||||
c_m_n_host_result(m0, m1, n0, n1),
|
||||
d_m_n(m0, m1, n0, n1));
|
||||
}
|
||||
else if constexpr(is_same<CDElementOp, Scale>::value)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m0, m1, n0, n1),
|
||||
c_m_n_host_result(m0, m1, n0, n1));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert("Unsupported CDElementOp in contraction profiler.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device op instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::unique_ptr<tensor_operation::device::BaseArgument> argument_ptr;
|
||||
if constexpr(is_same<CDElementOp, Bilinear>::value)
|
||||
{
|
||||
argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
a_ms_ks_lengths,
|
||||
StridesA,
|
||||
b_ns_ks_lengths,
|
||||
StridesB,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_m_n_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{StridesD},
|
||||
e_ms_ns_lengths,
|
||||
StridesE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
else if constexpr(is_same<CDElementOp, Scale>::value)
|
||||
{
|
||||
argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<DataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 0>{},
|
||||
static_cast<DataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
a_ms_ks_lengths,
|
||||
StridesA,
|
||||
b_ns_ks_lengths,
|
||||
StridesB,
|
||||
std::array<std::vector<ck::index_t>, 0>{},
|
||||
std::array<std::vector<ck::index_t>, 0>{},
|
||||
e_ms_ns_lengths,
|
||||
StridesE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert("Unsupported CDElementOp in contraction profiler.");
|
||||
}
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
auto nelems_m = M[0] * M[1];
|
||||
auto nelems_n = N[0] * N[1];
|
||||
auto nelems_k = K[0] * K[1];
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init C to zero before profiling next kernel
|
||||
e_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float avg_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * nelems_m * nelems_n * nelems_k;
|
||||
|
||||
std::size_t num_btype = sizeof(DataType) * nelems_m * nelems_k +
|
||||
sizeof(DataType) * nelems_k * nelems_n +
|
||||
sizeof(DataType) * nelems_m * nelems_n;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
float threshold =
|
||||
static_cast<DataType>(nelems_k) * std::numeric_limits<DataType>::epsilon();
|
||||
pass = pass & ck::utils::check_err(e_m_n_device_result,
|
||||
e_m_n_host_result,
|
||||
"Error: incorrect results!",
|
||||
threshold,
|
||||
threshold);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<DataType, float>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<DataType, double>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f64";
|
||||
}
|
||||
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<CDELayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " CDELayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<CDELayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " CDELayout = ColumnMajor";
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StridesA = " << StridesA
|
||||
<< " StridesB = " << StridesB << " StridesE = " << StridesE << " : " << best_avg_time
|
||||
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
|
||||
<< best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
51
profiler/include/profiler/profile_contraction_utils.hpp
Normal file
51
profiler/include/profiler/profile_contraction_utils.hpp
Normal file
@@ -0,0 +1,51 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
// Layout selector parsed from the command line (arg3). The numeric value of each
// enumerator is the integer the user passes, so the values are spelled out explicitly.
enum class ContractionMatrixLayout
{
    MK_KN_MN_MN = 0,
    MK_NK_MN_MN = 1,
    KM_KN_MN_MN = 2,
    KM_NK_MN_MN = 3,
};
|
||||
|
||||
// Data-type selector parsed from the command line (arg2). The numeric value of each
// enumerator is the integer the user passes.
enum class ContractionDataType
{
    F32_F32_F32_F32 = 0,
    F64_F64_F64_F64 = 1,
};
|
||||
|
||||
inline void collect_index_params(char* argv[],
|
||||
std::vector<ck::index_t>& params,
|
||||
const ck::index_t from,
|
||||
const ck::index_t num)
|
||||
{
|
||||
for(ck::index_t p = from; p < from + num; p++)
|
||||
params.push_back(std::stoi(argv[p]));
|
||||
}
|
||||
|
||||
// Default strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1}
|
||||
// Default strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1}
|
||||
inline void
|
||||
assign_default_strides(Row, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
|
||||
{
|
||||
strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
|
||||
}
|
||||
|
||||
inline void
|
||||
assign_default_strides(Col, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
|
||||
{
|
||||
strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]};
|
||||
}
|
||||
@@ -30,6 +30,8 @@ set(PROFILER_SOURCES
|
||||
profile_batchnorm_bwd.cpp
|
||||
profile_batchnorm_infer.cpp
|
||||
profile_grouped_gemm_fastgelu.cpp
|
||||
profile_contraction_bilinear.cpp
|
||||
profile_contraction_scale.cpp
|
||||
)
|
||||
|
||||
set(PROFILER_EXECUTABLE ckProfiler)
|
||||
@@ -70,4 +72,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
|
||||
|
||||
165
profiler/src/profile_contraction_bilinear.cpp
Normal file
165
profiler/src/profile_contraction_bilinear.cpp
Normal file
@@ -0,0 +1,165 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
|
||||
#include "profiler/profile_contraction_impl.hpp"
|
||||
#include "profiler/profile_contraction_utils.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
#define OP_NAME "contraction_bilinear"
|
||||
#define OP_DESC "CONTRACTION+Bilinear"
|
||||
|
||||
static void print_helper_msg()
|
||||
{
|
||||
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: f64)\n"
|
||||
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
|
||||
<< "arg4: verification (0: no; 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal "
|
||||
<< "value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< "arg8 and arg9: alpha and beta\n"
|
||||
<< "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
|
||||
<< "arg16 to 31: Strides for A, B, D and E (skip for default)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
int profile_contraction_bilinear(int argc, char* argv[])
|
||||
{
|
||||
const bool default_strides = argc == 16;
|
||||
|
||||
if(argc != 32 && argc != 16)
|
||||
{
|
||||
print_helper_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const ck::index_t init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const float alpha = std::stof(argv[8]);
|
||||
const float beta = std::stof(argv[9]);
|
||||
|
||||
std::vector<ck::index_t> M;
|
||||
std::vector<ck::index_t> N;
|
||||
std::vector<ck::index_t> K;
|
||||
const ck::index_t dims_arg_num = 10;
|
||||
collect_index_params(argv, M, dims_arg_num, 2);
|
||||
collect_index_params(argv, N, dims_arg_num + 2, 2);
|
||||
collect_index_params(argv, K, dims_arg_num + 4, 2);
|
||||
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesE;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
if(!default_strides)
|
||||
{
|
||||
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
|
||||
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
|
||||
collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
|
||||
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
|
||||
}
|
||||
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
|
||||
using ALayout = decltype(a_layout);
|
||||
using BLayout = decltype(b_layout);
|
||||
using CDELayout = decltype(cde_layout);
|
||||
|
||||
using DataType = decltype(type);
|
||||
|
||||
if(default_strides)
|
||||
{
|
||||
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
|
||||
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
|
||||
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
|
||||
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
|
||||
}
|
||||
bool pass = ck::profiler::profile_contraction_impl<ALayout,
|
||||
BLayout,
|
||||
CDELayout,
|
||||
DataType,
|
||||
ck::Tuple<DataType>,
|
||||
Bilinear>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Bilinear{alpha, beta},
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StridesA,
|
||||
StridesB,
|
||||
StridesE,
|
||||
StridesD);
|
||||
|
||||
return pass;
|
||||
};
|
||||
|
||||
if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::MK_KN_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::MK_NK_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::KM_KN_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::KM_NK_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::MK_KN_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::MK_NK_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Col{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::KM_KN_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Row{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::KM_NK_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F64{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
|
||||
162
profiler/src/profile_contraction_scale.cpp
Normal file
162
profiler/src/profile_contraction_scale.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
|
||||
#include "profiler/profile_contraction_impl.hpp"
|
||||
#include "profiler/profile_contraction_utils.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
#define OP_NAME "contraction_scale"
|
||||
#define OP_DESC "CONTRACTION+Scale"
|
||||
|
||||
static void print_helper_msg()
|
||||
{
|
||||
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
|
||||
<< "arg2: data type (0: fp32; 1: f64)\n"
|
||||
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n"
|
||||
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
|
||||
"D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n"
|
||||
<< "arg4: verification (0: no; 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal "
|
||||
<< "value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< "arg8: alpha\n"
|
||||
<< "arg9 to 14: M0, M1, N0, N1, K0, K1\n"
|
||||
<< "arg15 to 30: Strides for A, B, D and E (skip for default)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
int profile_contraction_scale(int argc, char* argv[])
|
||||
{
|
||||
const bool default_strides = argc == 15;
|
||||
|
||||
if(argc != 31 && argc != 15)
|
||||
{
|
||||
print_helper_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const ck::index_t init_method = std::stoi(argv[5]);
|
||||
const bool do_log = std::stoi(argv[6]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const float alpha = std::stof(argv[8]);
|
||||
|
||||
std::vector<ck::index_t> M;
|
||||
std::vector<ck::index_t> N;
|
||||
std::vector<ck::index_t> K;
|
||||
const ck::index_t dims_arg_num = 9;
|
||||
collect_index_params(argv, M, dims_arg_num, 2);
|
||||
collect_index_params(argv, N, dims_arg_num + 2, 2);
|
||||
collect_index_params(argv, K, dims_arg_num + 4, 2);
|
||||
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesE;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
if(!default_strides)
|
||||
{
|
||||
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
|
||||
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
|
||||
collect_index_params(argv, StridesE, dims_arg_num + 14, 4);
|
||||
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
|
||||
}
|
||||
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) {
|
||||
using ALayout = decltype(a_layout);
|
||||
using BLayout = decltype(b_layout);
|
||||
using CDELayout = decltype(cde_layout);
|
||||
|
||||
using DataType = decltype(type);
|
||||
|
||||
if(default_strides)
|
||||
{
|
||||
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
|
||||
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
|
||||
assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]});
|
||||
assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]});
|
||||
}
|
||||
|
||||
bool pass = ck::profiler::
|
||||
profile_contraction_impl<ALayout, BLayout, CDELayout, DataType, ck::Tuple<>, Scale>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Scale{alpha},
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StridesA,
|
||||
StridesB,
|
||||
StridesE,
|
||||
StridesD);
|
||||
|
||||
return pass;
|
||||
};
|
||||
|
||||
if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::MK_KN_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::MK_NK_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::KM_KN_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Row{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
|
||||
layout == ContractionMatrixLayout::KM_NK_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F32{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::MK_KN_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::MK_NK_MN_MN)
|
||||
{
|
||||
return profile(Row{}, Col{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::KM_KN_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Row{}, Row{}, F64{});
|
||||
}
|
||||
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
|
||||
layout == ContractionMatrixLayout::KM_NK_MN_MN)
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F64{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale);
|
||||
@@ -56,6 +56,7 @@ add_subdirectory(normalization)
|
||||
add_subdirectory(data_type)
|
||||
add_subdirectory(elementwise_normalization)
|
||||
add_subdirectory(batchnorm)
|
||||
add_subdirectory(contraction)
|
||||
if(GPU_TARGETS MATCHES "gfx1100")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
|
||||
4
test/contraction/CMakeLists.txt
Normal file
4
test/contraction/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
add_gtest_executable(test_contraction test_contraction.cpp)
|
||||
add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
|
||||
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
|
||||
target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
|
||||
138
test/contraction/test_contraction.cpp
Normal file
138
test/contraction/test_contraction.cpp
Normal file
@@ -0,0 +1,138 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_contraction_impl.hpp"
|
||||
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
struct MemoryParams
|
||||
{
|
||||
std::vector<ck::index_t> M;
|
||||
std::vector<ck::index_t> N;
|
||||
std::vector<ck::index_t> K;
|
||||
std::vector<ck::index_t> StridesA;
|
||||
std::vector<ck::index_t> StridesB;
|
||||
std::vector<ck::index_t> StridesC;
|
||||
std::vector<ck::index_t> StridesD;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContraction : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CDLayout = std::tuple_element_t<2, Tuple>;
|
||||
using DataType = std::tuple_element_t<3, Tuple>;
|
||||
using DTupleDataType = std::tuple_element_t<4, Tuple>;
|
||||
using CDElementOp = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
std::vector<MemoryParams> list_of_memory_params = {{{32, 32},
|
||||
{32, 32},
|
||||
{32, 32},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1},
|
||||
{32768, 1024, 32, 1}},
|
||||
{{16, 16},
|
||||
{32, 32},
|
||||
{16, 16},
|
||||
{4096, 256, 16, 1},
|
||||
{16, 1, 8192, 256},
|
||||
{16384, 1024, 32, 1},
|
||||
{16384, 1024, 32, 1}}};
|
||||
|
||||
std::vector<ck::index_t> init_methods = {0, 1, 2};
|
||||
std::unique_ptr<CDElementOp> p_cd_element_op;
|
||||
void Run()
|
||||
{
|
||||
for(auto& memory_params : list_of_memory_params)
|
||||
{
|
||||
for(const ck::index_t init_method : init_methods)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_contraction_impl<ALayout,
|
||||
BLayout,
|
||||
CDLayout,
|
||||
DataType,
|
||||
DTupleDataType,
|
||||
CDElementOp>(true /*do_verification*/,
|
||||
init_method,
|
||||
false /*do_logs*/,
|
||||
false /*time_kernel*/,
|
||||
*p_cd_element_op,
|
||||
memory_params.M,
|
||||
memory_params.N,
|
||||
memory_params.K,
|
||||
memory_params.StridesA,
|
||||
memory_params.StridesB,
|
||||
memory_params.StridesC,
|
||||
memory_params.StridesD);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContractionScale : public TestContraction<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestContractionBilinear : public TestContraction<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using BilinearKernelTypes =
|
||||
::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Row, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Col, Row, F32, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Row, Col, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Row, Row, F64, ck::Tuple<F32>, Bilinear>,
|
||||
std::tuple<Col, Col, Row, F64, ck::Tuple<F32>, Bilinear>>;
|
||||
|
||||
using ScaleKernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Col, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Row, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Col, Row, F32, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Row, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Row, Col, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Row, Row, F64, ck::Tuple<>, Scale>,
|
||||
std::tuple<Col, Col, Row, F64, ck::Tuple<>, Scale>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes);
|
||||
TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes);
|
||||
|
||||
// Runs the full fixture sweep twice: first with (alpha, beta) = (1, 1), then (-0.5, 0.5).
TYPED_TEST(TestContractionBilinear, bilinear)
{
    constexpr float alpha_beta_pairs[][2] = {{1.f, 1.f}, {-0.5f, 0.5f}};

    for(const auto& alpha_beta : alpha_beta_pairs)
    {
        this->p_cd_element_op = std::make_unique<Bilinear>(alpha_beta[0], alpha_beta[1]);
        this->Run();
    }
}
|
||||
|
||||
// Runs the full fixture sweep twice: first with alpha = 1, then alpha = 0.5.
TYPED_TEST(TestContractionScale, scale)
{
    for(const float alpha : {1.f, 0.5f})
    {
        this->p_cd_element_op = std::make_unique<Scale>(alpha);
        this->Run();
    }
}
|
||||
195
test/contraction/test_contraction_interface.cpp
Normal file
195
test/contraction/test_contraction_interface.cpp
Normal file
@@ -0,0 +1,195 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
template <ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t CDEBlockTransferScalarPerVector>
|
||||
class ContractionInstanceWrapper
|
||||
{
|
||||
public:
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr ck::index_t NumDim = 2;
|
||||
// clang-format off
|
||||
using ContractionDeviceInstance = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple<F32>, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>;
|
||||
// clang-format on
|
||||
|
||||
bool isSupported(std::vector<ck::index_t>& ADims,
|
||||
std::vector<ck::index_t>& BDims,
|
||||
std::vector<ck::index_t>& DDims,
|
||||
std::vector<ck::index_t>& EDims,
|
||||
std::vector<ck::index_t>& AStrides,
|
||||
std::vector<ck::index_t>& BStrides,
|
||||
std::vector<ck::index_t>& DStrides,
|
||||
std::vector<ck::index_t>& EStrides) const
|
||||
{
|
||||
auto contraction = ContractionDeviceInstance{};
|
||||
|
||||
auto argument = contraction.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
std::array<const void*, 1>{nullptr},
|
||||
nullptr,
|
||||
ADims,
|
||||
AStrides,
|
||||
BDims,
|
||||
BStrides,
|
||||
std::array<std::vector<ck::index_t>, 1>{DDims},
|
||||
std::array<std::vector<ck::index_t>, 1>{DStrides},
|
||||
EDims,
|
||||
EStrides,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Bilinear{1.f, 1.f});
|
||||
return contraction.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DataTypeA,
|
||||
typename DataTypeB,
|
||||
typename DataTypeC,
|
||||
typename DataTypeD,
|
||||
ck::index_t NumDim>
|
||||
class ContractionDeviceOpWrapper
|
||||
{
|
||||
|
||||
protected:
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD<NumDim,
|
||||
NumDim,
|
||||
NumDim,
|
||||
DataTypeA,
|
||||
DataTypeB,
|
||||
ck::Tuple<DataTypeC>,
|
||||
DataTypeD,
|
||||
Pass,
|
||||
Pass,
|
||||
Bilinear>;
|
||||
|
||||
public:
|
||||
bool IsSupportedInstance(std::vector<ck::index_t>& Dims,
|
||||
std::vector<ck::index_t>& Strides) const
|
||||
{
|
||||
|
||||
bool supported = false;
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(nullptr,
|
||||
nullptr,
|
||||
std::array<const void*, 1>{nullptr},
|
||||
nullptr,
|
||||
Dims,
|
||||
Strides,
|
||||
Dims,
|
||||
Strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{Dims},
|
||||
std::array<std::vector<ck::index_t>, 1>{Strides},
|
||||
Dims,
|
||||
Strides,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Bilinear{1.f, 1.f});
|
||||
|
||||
supported = supported || op_ptr->IsSupportedArgument(argument_ptr.get());
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
};
|
||||
|
||||
// Only the wrapper built with NumDim == 2 is expected to find a supporting
// instance; the NumDim == 1 and NumDim == 3 wrappers are expected to find none.
TEST(TestContractionInterface, IncorrectNumDims)
{
    // Every length is 4 and every stride is 1; the cases differ only in the
    // number of entries (2, 4 and 6).
    std::vector<ck::index_t> Dims2    = {4, 4};
    std::vector<ck::index_t> Dims4    = {4, 4, 4, 4};
    std::vector<ck::index_t> Dims6    = {4, 4, 4, 4, 4, 4};
    std::vector<ck::index_t> Strides2 = {1, 1};
    std::vector<ck::index_t> Strides4 = {1, 1, 1, 1};
    std::vector<ck::index_t> Strides6 = {1, 1, 1, 1, 1, 1};

    ContractionDeviceOpWrapper<F32, F32, F32, F32, 1> wrapper_1d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 2> wrapper_2d;
    ContractionDeviceOpWrapper<F32, F32, F32, F32, 3> wrapper_3d;

    EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims2, Strides2));
    EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims4, Strides4));
    EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims6, Strides6));
}
|
||||
|
||||
// Wrappers whose first two data types differ from the last two (F32/F64 mixed
// either way) are expected to find no supporting instance.
TEST(TestContractionInterface, IncorrectDataTypes)
{
    std::vector<ck::index_t> Lengths       = {4, 4, 4, 4};
    std::vector<ck::index_t> PackedStrides = {64, 16, 4, 1};

    ContractionDeviceOpWrapper<F32, F32, F64, F64, 2> f32_inputs_f64_outputs;
    ContractionDeviceOpWrapper<F64, F64, F32, F32, 2> f64_inputs_f32_outputs;

    EXPECT_FALSE(f32_inputs_f64_outputs.IsSupportedInstance(Lengths, PackedStrides));
    EXPECT_FALSE(f64_inputs_f32_outputs.IsSupportedInstance(Lengths, PackedStrides));
}
|
||||
|
||||
// Checks how the stride layout of A, and then of B, interacts with the source
// vector dimension the instance was built with. D and E always use the packed
// strides, so only the tensor under test varies.
TEST(TestContractionSupportedArgs, ABMemoryAccess)
{
    std::vector<ck::index_t> Lengths           = {4, 4, 4, 4};
    std::vector<ck::index_t> PackedStrides     = {64, 16, 4, 1};
    std::vector<ck::index_t> UnitStrideAtDim1  = {4, 1, 64, 16}; // stride 1 in the second entry
    std::vector<ck::index_t> UnitStrideAtDim3  = {64, 16, 4, 1}; // stride 1 in the last entry
    std::vector<ck::index_t> NoUnitStride      = {4, 4, 4, 4};   // no entry has stride 1

    // Memory access to A: instances built with A source vector dim 1 and 2.
    ContractionInstanceWrapper<1, 2, 4> a_vector_dim1;
    ContractionInstanceWrapper<2, 2, 4> a_vector_dim2;
    EXPECT_FALSE(a_vector_dim1.isSupported(Lengths,
                                           Lengths,
                                           Lengths,
                                           Lengths,
                                           NoUnitStride,
                                           PackedStrides,
                                           PackedStrides,
                                           PackedStrides));
    EXPECT_FALSE(a_vector_dim2.isSupported(Lengths,
                                           Lengths,
                                           Lengths,
                                           Lengths,
                                           NoUnitStride,
                                           PackedStrides,
                                           PackedStrides,
                                           PackedStrides));
    EXPECT_TRUE(a_vector_dim1.isSupported(Lengths,
                                          Lengths,
                                          Lengths,
                                          Lengths,
                                          UnitStrideAtDim1,
                                          PackedStrides,
                                          PackedStrides,
                                          PackedStrides));
    EXPECT_TRUE(a_vector_dim2.isSupported(Lengths,
                                          Lengths,
                                          Lengths,
                                          Lengths,
                                          UnitStrideAtDim3,
                                          PackedStrides,
                                          PackedStrides,
                                          PackedStrides));

    // Memory access to B: instances built with B source vector dim 1 and 2.
    ContractionInstanceWrapper<2, 1, 4> b_vector_dim1;
    ContractionInstanceWrapper<2, 2, 4> b_vector_dim2;
    EXPECT_FALSE(b_vector_dim1.isSupported(Lengths,
                                           Lengths,
                                           Lengths,
                                           Lengths,
                                           PackedStrides,
                                           NoUnitStride,
                                           PackedStrides,
                                           PackedStrides));
    EXPECT_FALSE(b_vector_dim2.isSupported(Lengths,
                                           Lengths,
                                           Lengths,
                                           Lengths,
                                           PackedStrides,
                                           NoUnitStride,
                                           PackedStrides,
                                           PackedStrides));
    EXPECT_TRUE(b_vector_dim1.isSupported(Lengths,
                                          Lengths,
                                          Lengths,
                                          Lengths,
                                          PackedStrides,
                                          UnitStrideAtDim1,
                                          PackedStrides,
                                          PackedStrides));
    EXPECT_TRUE(b_vector_dim2.isSupported(Lengths,
                                          Lengths,
                                          Lengths,
                                          Lengths,
                                          PackedStrides,
                                          UnitStrideAtDim3,
                                          PackedStrides,
                                          PackedStrides));
}
|
||||
|
||||
// Checks that strides whose last two entries are swapped ({64, 16, 1, 4}) are
// rejected when used for D or for E, while the packed strides are accepted.
TEST(TestContractionSupportedArgs, DEMemoryAccess)
{
    std::vector<ck::index_t> Lengths         = {4, 4, 4, 4};
    std::vector<ck::index_t> PackedStrides   = {64, 16, 4, 1};
    std::vector<ck::index_t> SwappedLastTwo  = {64, 16, 1, 4};

    ContractionInstanceWrapper<2, 2, 4> instance;

    // Memory access to D
    EXPECT_FALSE(instance.isSupported(Lengths,
                                      Lengths,
                                      Lengths,
                                      Lengths,
                                      PackedStrides,
                                      PackedStrides,
                                      SwappedLastTwo,
                                      PackedStrides));
    EXPECT_TRUE(instance.isSupported(Lengths,
                                     Lengths,
                                     Lengths,
                                     Lengths,
                                     PackedStrides,
                                     PackedStrides,
                                     PackedStrides,
                                     PackedStrides));

    // Memory access to E
    EXPECT_FALSE(instance.isSupported(Lengths,
                                      Lengths,
                                      Lengths,
                                      Lengths,
                                      PackedStrides,
                                      PackedStrides,
                                      PackedStrides,
                                      SwappedLastTwo));
    EXPECT_TRUE(instance.isSupported(Lengths,
                                     Lengths,
                                     Lengths,
                                     Lengths,
                                     PackedStrides,
                                     PackedStrides,
                                     PackedStrides,
                                     PackedStrides));
}
|
||||
Reference in New Issue
Block a user