mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][Test] Moving device_op creation before data initialization.
Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
committed by
Michał Kulikowski
parent
12599a6802
commit
a3feb9c1df
@@ -140,6 +140,25 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
std::cout << "wei: " << wei_host.mDesc << std::endl;
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ck::Tuple<WeiLayout>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::Tuple<WeiDataType>,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
d.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
@@ -179,23 +198,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
|
||||
RunReference(conv_param, in, wei_host, out, d);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ck::Tuple<WeiLayout>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ck::Tuple<WeiDataType>,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
int num_kernel = 0;
|
||||
|
||||
for(std::size_t i = 0; i < op_ptrs.size(); ++i)
|
||||
|
||||
@@ -104,6 +104,27 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
|
||||
// InDataType and WeiDataType must be tuple, inLayout and weiLayout are single.
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
ck::Tuple<InDataType, InDataType>,
|
||||
ck::Tuple<WeiDataType, WeiDataType>,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
@@ -246,27 +267,6 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
// InDataType and WeiDataType must be tuple, inLayout and weiLayout are single.
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
ck::Tuple<InDataType, InDataType>,
|
||||
ck::Tuple<WeiDataType, WeiDataType>,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::array<const void*, NumAs> as{in_device_buf.GetDeviceBuffer(),
|
||||
in_bias_device_buf.GetDeviceBuffer()};
|
||||
std::array<const void*, NumBs> bs{wei_device_buf.GetDeviceBuffer(),
|
||||
|
||||
Reference in New Issue
Block a user