mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK tests] Extend conv GPU reference (#3539)
* test_convnd_fwd
* test_convnd_bwd_data
* test_conv_bwd_data_scale
* test_grouped_convnd_fwd_clamp
* test_grouped_convnd_fwd_scale
* multiple A/B tensors and D tensor for fwd GPU ref
* test_grouped_convnd_fwd_scaleadd_ab
* test_grouped_convnd_fwd_bias_clamp
* test_grouped_convnd_fwd_bilinear
* test_grouped_convnd_fwd_gk_bias_clamp
* Extend GPU reference to enable batchnorm epilogue
* test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp
* test_grouped_conv_bwd_data_bilinear
* test_grouped_convnd_bwd_weight_bilinear
* Add missing template instantiation
* Perform operations in float in reference
* Slightly increase tolerance for batchnorm profiler
* Revert "Slightly increase tolerance for batchnorm profiler"
This reverts commit a3b2475229.
* Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp"
This reverts commit 6da4576060.
* Revert "Extend GPU reference to enable batchnorm epilogue"
This reverts commit e2f75fa10e.
* Clarify variable names
* Refactor elementwise ops into helper functions
* Make helpers C++17-compatible
[ROCm/composable_kernel commit: c190d8d61f]
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp)
|
||||
target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility)
|
||||
|
||||
add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp)
|
||||
target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility)
|
||||
|
||||
add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp)
|
||||
target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility)
|
||||
|
||||
|
||||
@@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType
|
||||
}
|
||||
}
|
||||
|
||||
// Forward convolution with D tensor support
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename OutElementOp>
|
||||
bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params,
|
||||
const Tensor<InDataType>& input_cpu,
|
||||
const Tensor<WeiDataType>& weight_cpu,
|
||||
const Tensor<OutDataType>& d_cpu,
|
||||
DeviceMem& input_dev,
|
||||
DeviceMem& weight_dev,
|
||||
DeviceMem& d_dev,
|
||||
DeviceMem& output_dev,
|
||||
OutElementOp out_element_op)
|
||||
{
|
||||
using InElementOp = tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Create D tensor lengths and strides for GPU reference
|
||||
std::vector<index_t> d_lengths_vec(NDimSpatial + 3);
|
||||
d_lengths_vec[0] = params.G_;
|
||||
d_lengths_vec[1] = params.N_;
|
||||
d_lengths_vec[2] = params.K_;
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
d_lengths_vec[3 + i] = static_cast<index_t>(params.output_spatial_lengths_[i]);
|
||||
}
|
||||
|
||||
std::vector<index_t> d_strides_vec =
|
||||
ref::compute_conv_tensor_strides<OutLayout>(d_lengths_vec, params.num_dim_spatial_);
|
||||
|
||||
std::array<const OutDataType*, 1> d_ptrs = {
|
||||
reinterpret_cast<const OutDataType*>(d_dev.GetDeviceBuffer())};
|
||||
std::array<std::vector<index_t>, 1> d_lengths = {d_lengths_vec};
|
||||
std::array<std::vector<index_t>, 1> d_strides = {d_strides_vec};
|
||||
|
||||
// Call GPU reference with D tensor
|
||||
std::array<const InDataType*, 1> in_ptrs = {
|
||||
reinterpret_cast<const InDataType*>(input_dev.GetDeviceBuffer())};
|
||||
std::array<const WeiDataType*, 1> wei_ptrs = {
|
||||
reinterpret_cast<const WeiDataType*>(weight_dev.GetDeviceBuffer())};
|
||||
|
||||
ref::naive_conv_fwd_multi_abd<0,
|
||||
0,
|
||||
1,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
OutDataType>( // Explicitly specify TD = OutDataType
|
||||
in_ptrs,
|
||||
wei_ptrs,
|
||||
d_ptrs,
|
||||
reinterpret_cast<OutDataType*>(output_dev.GetDeviceBuffer()),
|
||||
params,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
out_element_op);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Run CPU reference
|
||||
std::vector<long_index_t> strides_long(params.conv_filter_strides_.begin(),
|
||||
params.conv_filter_strides_.end());
|
||||
std::vector<long_index_t> dilations_long(params.conv_filter_dilations_.begin(),
|
||||
params.conv_filter_dilations_.end());
|
||||
std::vector<long_index_t> pads_long(params.input_left_pads_.begin(),
|
||||
params.input_left_pads_.end());
|
||||
|
||||
Tensor<InDataType> input_ref = input_cpu;
|
||||
Tensor<WeiDataType> weight_ref = weight_cpu;
|
||||
Tensor<OutDataType> output_ref(
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params));
|
||||
|
||||
std::array<Tensor<OutDataType>, 1> d_tensors_ref = {d_cpu};
|
||||
|
||||
auto ref_conv = tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
0, // NumA
|
||||
0, // NumB
|
||||
1 // NumD
|
||||
>();
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_arg = ref_conv.MakeArgument(input_ref,
|
||||
weight_ref,
|
||||
output_ref,
|
||||
strides_long,
|
||||
dilations_long,
|
||||
pads_long,
|
||||
pads_long,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
out_element_op,
|
||||
{}, // A tensors
|
||||
{}, // B tensors
|
||||
d_tensors_ref);
|
||||
ref_invoker.Run(ref_arg);
|
||||
|
||||
// Copy result from device and compare
|
||||
Tensor<OutDataType> output_gpu(output_ref.mDesc);
|
||||
output_dev.FromDevice(output_gpu.mData.data());
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Compare results
|
||||
return ck::utils::check_err(output_gpu, output_ref);
|
||||
}
|
||||
|
||||
// Forward convolution with multiple A/B tensor support
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp>
|
||||
bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params,
|
||||
const Tensor<InDataType>& input_cpu,
|
||||
const Tensor<WeiDataType>& weight_cpu,
|
||||
const Tensor<InDataType>& a_extra_cpu,
|
||||
const Tensor<WeiDataType>& b_extra_cpu,
|
||||
DeviceMem& input_dev,
|
||||
DeviceMem& weight_dev,
|
||||
DeviceMem& a_extra_dev,
|
||||
DeviceMem& b_extra_dev,
|
||||
DeviceMem& output_dev,
|
||||
InElementOp in_element_op,
|
||||
WeiElementOp wei_element_op)
|
||||
{
|
||||
using OutElementOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// Call GPU reference with extra A and B tensors
|
||||
std::array<const InDataType*, 2> in_ptrs = {
|
||||
reinterpret_cast<const InDataType*>(input_dev.GetDeviceBuffer()),
|
||||
reinterpret_cast<const InDataType*>(a_extra_dev.GetDeviceBuffer())};
|
||||
std::array<const WeiDataType*, 2> wei_ptrs = {
|
||||
reinterpret_cast<const WeiDataType*>(weight_dev.GetDeviceBuffer()),
|
||||
reinterpret_cast<const WeiDataType*>(b_extra_dev.GetDeviceBuffer())};
|
||||
std::array<const OutDataType*, 0> d_ptrs = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>(
|
||||
in_ptrs,
|
||||
wei_ptrs,
|
||||
d_ptrs,
|
||||
reinterpret_cast<OutDataType*>(output_dev.GetDeviceBuffer()),
|
||||
params,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
OutElementOp{});
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Run CPU reference
|
||||
std::vector<long_index_t> strides_long(params.conv_filter_strides_.begin(),
|
||||
params.conv_filter_strides_.end());
|
||||
std::vector<long_index_t> dilations_long(params.conv_filter_dilations_.begin(),
|
||||
params.conv_filter_dilations_.end());
|
||||
std::vector<long_index_t> pads_long(params.input_left_pads_.begin(),
|
||||
params.input_left_pads_.end());
|
||||
|
||||
Tensor<InDataType> input_ref = input_cpu;
|
||||
Tensor<WeiDataType> weight_ref = weight_cpu;
|
||||
Tensor<OutDataType> output_ref(
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params));
|
||||
|
||||
std::array<Tensor<InDataType>, 1> a_tensors_ref = {a_extra_cpu};
|
||||
std::array<Tensor<WeiDataType>, 1> b_tensors_ref = {b_extra_cpu};
|
||||
|
||||
auto ref_conv = tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
1, // NumA
|
||||
1, // NumB
|
||||
0 // NumD
|
||||
>();
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_arg = ref_conv.MakeArgument(input_ref,
|
||||
weight_ref,
|
||||
output_ref,
|
||||
strides_long,
|
||||
dilations_long,
|
||||
pads_long,
|
||||
pads_long,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
OutElementOp{},
|
||||
a_tensors_ref,
|
||||
b_tensors_ref,
|
||||
{});
|
||||
ref_invoker.Run(ref_arg);
|
||||
|
||||
// Copy result from device and compare
|
||||
Tensor<OutDataType> output_gpu(output_ref.mDesc);
|
||||
output_dev.FromDevice(output_gpu.mData.data());
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
// Compare results
|
||||
return ck::utils::check_err(output_gpu, output_ref);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
|
||||
319
test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp
Normal file
319
test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp
Normal file
@@ -0,0 +1,319 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "gpu_reference_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
|
||||
using namespace ck;
|
||||
using ck::test::ConvKernelType;
|
||||
|
||||
// ==================== D Tensor (Bias) Tests ====================
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params)
|
||||
{
|
||||
using tensor_operation::element_wise::AddClamp;
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
|
||||
|
||||
// Create tensors
|
||||
Tensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> bias(out_g_n_k_wos_desc); // Same shape as output
|
||||
|
||||
// Allocate device memory
|
||||
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
|
||||
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
|
||||
DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType));
|
||||
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
|
||||
|
||||
// Initialize and copy tensors
|
||||
test::initialize_and_copy_tensor(input, input_dev);
|
||||
test::initialize_and_copy_tensor(weight, weight_dev);
|
||||
test::initialize_and_copy_tensor(bias, bias_dev);
|
||||
|
||||
// Test with AddClamp (bias operation with clamping)
|
||||
AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6
|
||||
|
||||
return test::test_conv_fwd_with_d_tensor_impl<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op);
|
||||
}
|
||||
|
||||
// Bias (AddClamp) coverage: 2-D and 3-D, fp16 and fp32, grouped and ungrouped.
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias)
{
    const auto params = test::conv_test_shapes::get_2d_small();
    EXPECT_TRUE((test_conv_gpu_ref_with_bias<2,
                                             half_t,
                                             half_t,
                                             half_t,
                                             tensor_layout::convolution::GNCHW,
                                             tensor_layout::convolution::GKCYX,
                                             tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias)
{
    const auto params = test::conv_test_shapes::get_2d_medium();
    EXPECT_TRUE((test_conv_gpu_ref_with_bias<2,
                                             float,
                                             float,
                                             float,
                                             tensor_layout::convolution::GNCHW,
                                             tensor_layout::convolution::GKCYX,
                                             tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias)
{
    const auto params = test::conv_test_shapes::get_3d_small();
    EXPECT_TRUE((test_conv_gpu_ref_with_bias<3,
                                             float,
                                             float,
                                             float,
                                             tensor_layout::convolution::GNCDHW,
                                             tensor_layout::convolution::GKCZYX,
                                             tensor_layout::convolution::GNKDHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias)
{
    const auto params = test::conv_test_shapes::get_2d_grouped_g2();
    EXPECT_TRUE((test_conv_gpu_ref_with_bias<2,
                                             half_t,
                                             half_t,
                                             half_t,
                                             tensor_layout::convolution::GNCHW,
                                             tensor_layout::convolution::GKCYX,
                                             tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias)
{
    const auto params = test::conv_test_shapes::get_2d_grouped_g4();
    EXPECT_TRUE((test_conv_gpu_ref_with_bias<2,
                                             float,
                                             float,
                                             float,
                                             tensor_layout::convolution::GNCHW,
                                             tensor_layout::convolution::GKCYX,
                                             tensor_layout::convolution::GNKHW>(params)));
}
|
||||
|
||||
// ==================== D Tensor (Bilinear) Tests ====================
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params)
|
||||
{
|
||||
using tensor_operation::element_wise::Bilinear;
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
|
||||
|
||||
// Create tensors
|
||||
Tensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
Tensor<OutDataType> d_tensor(out_g_n_k_wos_desc); // Same shape as output
|
||||
|
||||
// Allocate device memory
|
||||
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
|
||||
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
|
||||
DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType));
|
||||
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
|
||||
|
||||
// Initialize and copy tensors
|
||||
test::initialize_and_copy_tensor(input, input_dev);
|
||||
test::initialize_and_copy_tensor(weight, weight_dev);
|
||||
test::initialize_and_copy_tensor(d_tensor, d_dev);
|
||||
|
||||
// Test with Bilinear: y = alpha * conv_result + beta * d_tensor
|
||||
Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5
|
||||
|
||||
return test::test_conv_fwd_with_d_tensor_impl<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op);
|
||||
}
|
||||
|
||||
// Bilinear coverage: 2-D fp16/fp32, plus a grouped (G=2) case.
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear)
{
    const auto params = test::conv_test_shapes::get_2d_small();
    EXPECT_TRUE((test_conv_gpu_ref_with_bilinear<2,
                                                 half_t,
                                                 half_t,
                                                 half_t,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear)
{
    const auto params = test::conv_test_shapes::get_2d_medium();
    EXPECT_TRUE((test_conv_gpu_ref_with_bilinear<2,
                                                 float,
                                                 float,
                                                 float,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear)
{
    const auto params = test::conv_test_shapes::get_2d_grouped_g2();
    EXPECT_TRUE((test_conv_gpu_ref_with_bilinear<2,
                                                 half_t,
                                                 half_t,
                                                 half_t,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}
|
||||
|
||||
// ==================== Multiple A/B (ScaleAdd) Tests ====================
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params)
|
||||
{
|
||||
using tensor_operation::element_wise::ScaleAdd;
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
|
||||
|
||||
// Create tensors
|
||||
Tensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
Tensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
Tensor<InDataType> a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input)
|
||||
Tensor<WeiDataType> b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight)
|
||||
|
||||
// Allocate device memory
|
||||
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
|
||||
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
|
||||
DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType));
|
||||
DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType));
|
||||
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
|
||||
|
||||
// Initialize and copy tensors
|
||||
test::initialize_and_copy_tensor(input, input_dev);
|
||||
test::initialize_and_copy_tensor(weight, weight_dev);
|
||||
test::initialize_and_copy_tensor(a_extra, a_extra_dev);
|
||||
test::initialize_and_copy_tensor(b_extra, b_extra_dev);
|
||||
|
||||
// Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1
|
||||
ScaleAdd in_element_op(2.0f); // scale factor for input
|
||||
ScaleAdd wei_element_op(1.5f); // scale factor for weight
|
||||
|
||||
return test::test_conv_fwd_with_multi_ab_impl<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(params,
|
||||
input,
|
||||
weight,
|
||||
a_extra,
|
||||
b_extra,
|
||||
input_dev,
|
||||
weight_dev,
|
||||
a_extra_dev,
|
||||
b_extra_dev,
|
||||
output_dev,
|
||||
in_element_op,
|
||||
wei_element_op);
|
||||
}
|
||||
|
||||
// ScaleAdd (multi A/B) coverage: 2-D fp16/fp32, plus a grouped (G=2) case.
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd)
{
    const auto params = test::conv_test_shapes::get_2d_small();
    EXPECT_TRUE((test_conv_gpu_ref_with_scaleadd<2,
                                                 half_t,
                                                 half_t,
                                                 half_t,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd)
{
    const auto params = test::conv_test_shapes::get_2d_medium();
    EXPECT_TRUE((test_conv_gpu_ref_with_scaleadd<2,
                                                 float,
                                                 float,
                                                 float,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}

TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd)
{
    const auto params = test::conv_test_shapes::get_2d_grouped_g2();
    EXPECT_TRUE((test_conv_gpu_ref_with_scaleadd<2,
                                                 half_t,
                                                 half_t,
                                                 half_t,
                                                 tensor_layout::convolution::GNCHW,
                                                 tensor_layout::convolution::GKCYX,
                                                 tensor_layout::convolution::GNKHW>(params)));
}
|
||||
Reference in New Issue
Block a user