mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Add contraction profiler and tests (#701)
* Add contraction profiler and tests
* Build and style fixes
* Allow to use any elementwise operator for ref_contraction
* Introduce profile_contraction_scale and profile_contraction_bilinear
* Make ref_contraction generic and extend interface tests
* Stylistic minor fixes
* Extend test_contraction_interface
[ROCm/composable_kernel commit: 642d5e9155]
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -385,22 +251,22 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
@@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device::
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKN;
|
||||
|
||||
// hardcoded for NumDimM == NumDimN == NumDimK == 2
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
|
||||
struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_ms_ks_{a_ms_ks},
|
||||
b_ns_ks_{b_ns_ks},
|
||||
e_ms_ns_{e_ms_ns},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ADataType>& a_ms_ks_;
|
||||
const Tensor<BDataType>& b_ns_ks_;
|
||||
Tensor<EDataType>& e_ms_ns_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public ck::tensor_operation::device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceContraction_M2_N2_K2::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) {
|
||||
const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2];
|
||||
const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3];
|
||||
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(int k0 = 0; k0 < K0; ++k0)
|
||||
{
|
||||
for(int k1 = 0; k1 < K1; ++k1)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
arg.a_element_op_(
|
||||
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1)));
|
||||
arg.b_element_op_(
|
||||
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
}
|
||||
|
||||
AccDataType v_c;
|
||||
|
||||
arg.cde_element_op_(v_c, v_acc);
|
||||
|
||||
arg.e_ms_ns_(m0, m1, n0, n1) = v_c;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_ms_ns,
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[0],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[1],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[2],
|
||||
arg.e_ms_ns_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const ck::tensor_operation::device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static auto MakeArgument(const Tensor<ADataType>& a_ms_ks,
|
||||
const Tensor<BDataType>& b_ns_ks,
|
||||
Tensor<EDataType>& e_ms_ns,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<ck::tensor_operation::device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceContraction_M2_N2_K2"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -368,22 +234,23 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{});
|
||||
Tensor<float> empty_tensor(std::vector<ck::index_t>{}, std::vector<ck::index_t>{});
|
||||
auto ref_argument =
|
||||
ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user