mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[CK][CK Tile] Add grouped conv backward weight tile test and fix tr load in BASE_V1 pipeline (#5115)
## Motivation
Test grouped conv backward weight from CK Tile and fix incorrect values.

## Technical Details
- Add test for CI
- Add daily tests
- Fix transpose load in BASE_V1 pipeline

## Test Plan
test_grouped_convnd_bwd_weight_tile

## Test Result
In progress

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-783
This commit is contained in:
3
Jenkinsfile
vendored
3
Jenkinsfile
vendored
@@ -1524,7 +1524,8 @@ pipeline {
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ python3 ../experimental/grouped_convolution_tile_instances/generate_instances.py --mode=profiler && \
|
||||
cmake .. --preset dev-gfx90a -D CK_EXPERIMENTAL_BUILDER=ON && \
|
||||
make -j64 test_grouped_convnd_fwd_tile && \
|
||||
make -j64 test_grouped_convnd_fwd_tile test_grouped_convnd_bwd_weight_tile && \
|
||||
./bin/test_grouped_convnd_bwd_weight_tile && \
|
||||
./bin/test_grouped_convnd_fwd_tile"""
|
||||
}
|
||||
steps{
|
||||
|
||||
@@ -288,6 +288,8 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
end = instance.rindex('>')
|
||||
params_str = instance[start:end]
|
||||
args = parse_instance_string(params_str)
|
||||
|
||||
direct_load = False
|
||||
|
||||
is_v3_instance = instance.find("Xdl_CShuffleV3") != -1
|
||||
is_two_stage_instance = instance.find("TwoStage") != -1
|
||||
@@ -357,6 +359,7 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
if len(args) != 45:
|
||||
raise RuntimeError(f"Wrong number of parameters in the V3 XDL CShuffle instance string: {instance}")
|
||||
|
||||
direct_load = int(args[43]) == 1
|
||||
num_groups_to_merge = int(args[44])
|
||||
|
||||
# Block GEMM pipeline parameters
|
||||
@@ -396,6 +399,16 @@ def parse_bwd_weight_instances(instances, problem_name):
|
||||
if pipeline_version == "V5":
|
||||
pipeline_version = "V6"
|
||||
|
||||
if direct_load:
|
||||
if pipeline_version == "V1":
|
||||
pipeline_version = "ASYNC_V1"
|
||||
elif pipeline_version == "V4":
|
||||
pipeline_version = "ASYNC_V4"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not supported pipeline for direct load: pipeline_version={pipeline_version} in instance: {instance}"
|
||||
)
|
||||
|
||||
m_warp = int(m_per_block / (m_per_xdl * m_xdl_per_wave))
|
||||
n_warp = int(n_per_block / (n_per_xdl * n_xdl_per_wave))
|
||||
warp_size = 64
|
||||
|
||||
@@ -112,6 +112,9 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
@@ -272,29 +275,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,7 +311,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
@@ -322,29 +324,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
@@ -354,7 +354,8 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
@@ -479,29 +480,27 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -518,36 +517,38 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res);
|
||||
store_tile(a_copy_lds_window, a_shuffle_tmp_loop);
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res);
|
||||
store_tile(b_copy_lds_window, b_shuffle_tmp_loop);
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
@@ -558,7 +559,11 @@ struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Pr
|
||||
{
|
||||
block_sync_lds();
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
|
||||
@@ -18,6 +18,17 @@ elseif(DL_KERNELS)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
endif()
|
||||
|
||||
# CK Tile backward-weight test: only configured when a gfx9 GPU target is requested
# and the experimental builder is enabled.
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_tile test_grouped_convnd_bwd_weight_tile.cpp)
|
||||
# Silence these three warning groups for this test target only.
target_compile_options(test_grouped_convnd_bwd_weight_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_tile PRIVATE gtest_main getopt::getopt utility)
|
||||
# The tile instance library is linked only when that target has been defined.
if(TARGET device_grouped_conv_bwd_weight_tile_instances)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_tile PRIVATE device_grouped_conv_bwd_weight_tile_instances)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface_xdl PRIVATE utility)
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "profiler/grouped_convolution_backward_weight_tile_algs.hpp"
|
||||
|
||||
// Bitmask selecting which entries of a fixture's conv_args list are executed: Run() skips
// entry i when bit i is clear. main() overwrites it from the first command-line argument.
static ck::index_t args_mask = 0xffff;
|
||||
// Instance selector taken from the second command-line argument ("-1 means all" per the
// usage text printed by main()).
// NOTE(review): nothing in this file reads the value after main() stores it - confirm
// whether it is meant to be forwarded to the algorithm runner.
static ck::index_t instance_index = -1;
|
||||
|
||||
// Short aliases for the CK Tile builder namespaces used throughout this file.
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace ckp = ck_tile::builder::profiling;
|
||||
|
||||
template <ck_tile::index_t num_spatial_dim_,
|
||||
ckb::DataType data_type_,
|
||||
ckb::DataType acc_data_type_,
|
||||
ckb::TensorLayout in_layout_,
|
||||
ckb::TensorLayout wei_layout_,
|
||||
ckb::TensorLayout out_layout_>
|
||||
struct SignatureDetails
|
||||
{
|
||||
static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_;
|
||||
static constexpr ckb::DataType data_type = data_type_;
|
||||
static constexpr ckb::DataType acc_data_type = acc_data_type_;
|
||||
static constexpr ckb::TensorLayout in_layout = in_layout_;
|
||||
static constexpr ckb::TensorLayout wei_layout = wei_layout_;
|
||||
static constexpr ckb::TensorLayout out_layout = out_layout_;
|
||||
};
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdWeightTile : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = SignatureDetailsType::data_type,
|
||||
.accumulation_data_type = SignatureDetailsType::acc_data_type,
|
||||
.input = {.config = {.layout = SignatureDetailsType::in_layout}},
|
||||
.weight = {.config = {.layout = SignatureDetailsType::wei_layout}},
|
||||
.output = {.config = {.layout = SignatureDetailsType::out_layout}}};
|
||||
|
||||
std::vector<ckt::Args<SIGNATURE>> conv_args;
|
||||
std::vector<std::string> split_ks{"-1", "1", "2"};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
ASSERT_FALSE(conv_args.empty());
|
||||
bool pass = true;
|
||||
for(size_t i = 0; i < conv_args.size(); i++)
|
||||
{
|
||||
for(auto& split_k : split_ks)
|
||||
{
|
||||
if((args_mask & (1 << i)) == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto& args = conv_args[i];
|
||||
|
||||
auto inputs = alloc_inputs(args);
|
||||
auto outputs = alloc_outputs(args);
|
||||
ckt::init_tensor_buffer_uniform_int(
|
||||
inputs.get().input, args.make_input_descriptor(), -5, 5);
|
||||
ckt::init_tensor_buffer_uniform_int(
|
||||
inputs.get().output, args.make_output_descriptor(), -5, 5);
|
||||
|
||||
std::cout << args.make_input_descriptor() << std::endl;
|
||||
std::cout << args.make_weight_descriptor() << std::endl;
|
||||
std::cout << args.make_output_descriptor() << std::endl;
|
||||
[[maybe_unused]] auto&& [case_passed, avg_time, op_name, best_split_k] =
|
||||
|
||||
ckp::run_grouped_conv_backward_weight_tile_algs(
|
||||
args,
|
||||
split_k,
|
||||
inputs.get(),
|
||||
outputs.get(),
|
||||
ck_tile::stream_config{nullptr, false /*time_kernel*/});
|
||||
|
||||
pass = pass && case_passed;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
void conv_args_append(std::size_t,
|
||||
std::size_t G,
|
||||
std::size_t N,
|
||||
std::size_t K,
|
||||
std::size_t C,
|
||||
const std::vector<std::size_t>& filter_spatial_lengths,
|
||||
const std::vector<std::size_t>& input_spatial_lengths,
|
||||
const std::vector<std::size_t>& conv_filter_strides,
|
||||
const std::vector<std::size_t>& conv_filter_dilations,
|
||||
const std::vector<std::size_t>& input_left_pads,
|
||||
const std::vector<std::size_t>& input_right_pads)
|
||||
{
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = N,
|
||||
.groups = G,
|
||||
.input_channels = C,
|
||||
.output_channels = K,
|
||||
.image = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_spatial_lengths),
|
||||
.filter = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
filter_spatial_lengths),
|
||||
},
|
||||
.filter_strides = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
conv_filter_strides),
|
||||
.filter_dilation =
|
||||
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
conv_filter_dilations),
|
||||
.input_left_pad = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_left_pads),
|
||||
.input_right_pad =
|
||||
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_right_pads),
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
conv_args.push_back(args);
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes2d = ::testing::Types<SignatureDetails<2,
|
||||
ckb::DataType::FP32,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>,
|
||||
SignatureDetails<2,
|
||||
ckb::DataType::FP16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>,
|
||||
SignatureDetails<2,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<SignatureDetails<3,
|
||||
ckb::DataType::FP32,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>,
|
||||
SignatureDetails<3,
|
||||
ckb::DataType::FP16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>,
|
||||
SignatureDetails<3,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>>;
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdWeightTile2d : public TestGroupedConvndBwdWeightTile<SignatureDetailsType>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdWeightTile3d : public TestGroupedConvndBwdWeightTile<SignatureDetailsType>
|
||||
{
|
||||
};
|
||||
|
||||
// Register each fixture with its type list so every TYPED_TEST below is instantiated
// once per listed signature.
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightTile3d, KernelTypes3d);
|
||||
|
||||
// Registers the 2D problem set on the fixture and runs it.
TYPED_TEST(TestGroupedConvndBwdWeightTile2d, Test2D)
{
    using Dims = std::vector<std::size_t>;

    // Forwards one problem to the fixture. The leading argument of conv_args_append is
    // always 2 for this suite.
    const auto add_case = [this](std::size_t G,
                                 std::size_t N,
                                 std::size_t K,
                                 std::size_t C,
                                 const Dims& filter,
                                 const Dims& image,
                                 const Dims& strides,
                                 const Dims& dilations,
                                 const Dims& left_pads,
                                 const Dims& right_pads) {
        this->conv_args_append(
            2, G, N, K, C, filter, image, strides, dilations, left_pads, right_pads);
    };

    this->conv_args.clear();

    // Columns: G, N, K, C, filter, image, strides, dilations, left pads, right pads.
    add_case(2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
    add_case(2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
    add_case(2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
    add_case(2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
    add_case(2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    add_case(2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
    add_case(1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    add_case(1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    add_case(1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    add_case(16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1});

    this->template Run<2>();
}
|
||||
|
||||
// Registers the 3D problem set on the fixture and runs it.
TYPED_TEST(TestGroupedConvndBwdWeightTile3d, Test3D)
{
    using Dims = std::vector<std::size_t>;

    // Forwards one problem to the fixture. The leading argument of conv_args_append is
    // always 3 for this suite.
    const auto add_case = [this](std::size_t G,
                                 std::size_t N,
                                 std::size_t K,
                                 std::size_t C,
                                 const Dims& filter,
                                 const Dims& image,
                                 const Dims& strides,
                                 const Dims& dilations,
                                 const Dims& left_pads,
                                 const Dims& right_pads) {
        this->conv_args_append(
            3, G, N, K, C, filter, image, strides, dilations, left_pads, right_pads);
    };

    this->conv_args.clear();

    // Columns: G, N, K, C, filter, image, strides, dilations, left pads, right pads.
    add_case(2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
    add_case(2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    add_case(2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
    add_case(1, 1, 1, 32, {3, 3, 3}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    add_case(1, 1, 64, 3, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    add_case(1, 1, 1, 1, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    add_case(16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});

    this->template Run<3>();
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
args_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user