mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Fix CI
This commit is contained in:
@@ -1496,8 +1496,10 @@ struct ElementWiseAdd
|
||||
* @note [return] Perform element-wise addition and store the result in 'r'
|
||||
*/
|
||||
template <typename ResT, typename ParamT>
|
||||
CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const
|
||||
-> void
|
||||
CK_TILE_DEVICE auto operator()(ResT& r,
|
||||
[[maybe_unused]] const ParamT& a,
|
||||
[[maybe_unused]] const ParamT& b,
|
||||
[[maybe_unused]] const ParamT& c) const -> void
|
||||
{
|
||||
r = a + b + c;
|
||||
}
|
||||
@@ -1536,7 +1538,7 @@ struct ElementWiseMul
|
||||
CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const
|
||||
-> void
|
||||
{
|
||||
r = a + b + c;
|
||||
r = a * b * c;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -154,7 +154,7 @@ struct CShuffleEpilogue
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
onst DsDramWindows& ds_dram_window,
|
||||
const DsDramWindows& ds_dram_window,
|
||||
void* p_smem)
|
||||
{
|
||||
|
||||
@@ -190,10 +190,6 @@ struct CShuffleEpilogue
|
||||
[&](auto idx) { return make_tile_window(ds_dram_window[idx], dram_tile_distribution); },
|
||||
number<NumDTensor>{});
|
||||
|
||||
using elemenet_wise_output_t =
|
||||
decltype(load_tile(make_tile_window(out_lds_window, dram_tile_distribution)));
|
||||
elemenet_wise_output_t elemenet_wise_output;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
@@ -215,26 +211,26 @@ struct CShuffleEpilogue
|
||||
store_tile(in_lds_window, c_warp_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
const auto c_out_tensor =
|
||||
auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(elemenet_wise_output, c_out_tensor),
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto i) -> const auto& { return ds_tensor[i]; }, number<NumDTensor>{}));
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_in_out_unpack_tuple(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
store_tile(out_dram_window, elemenet_wise_output);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
update_tile(out_dram_window, elemenet_wise_output);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
|
||||
@@ -9,16 +9,9 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
||||
/// object. It contains all necessary information required to build proper kernel argument
|
||||
/// and launch kernel on GPU.
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GemmHostArgs
|
||||
{
|
||||
@@ -64,75 +57,23 @@ struct GemmHostArgs
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
|
||||
template <typename DType = tuple<>>
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
/// @brief The A input tensor's pointer to device memory.
|
||||
const void* a_ptr;
|
||||
/// @brief The B input tensor's pointer to device memory.
|
||||
const void* b_ptr;
|
||||
/// @brief The Ds input tensor's tuple to device memory.
|
||||
const DType ds_ptr;
|
||||
/// @brief The C output tensor's pointer to device memory.
|
||||
void* c_ptr;
|
||||
/// @brief GEMM's M dimension size.
|
||||
index_t M;
|
||||
/// @brief GEMM's N dimension size.
|
||||
index_t N;
|
||||
/// @brief GEMM's K dimension size.
|
||||
index_t K;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of A tensor.
|
||||
index_t stride_A;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of B tensor.
|
||||
index_t stride_B;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of Ds tensor.
|
||||
const index_t* stride_Ds;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of C tensor.
|
||||
index_t stride_C;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the generic matrix multiplication kernel template. By semantic
|
||||
/// division of GEMM algorithm into following parts we achieve flexible, versatile
|
||||
/// and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// function call operator, which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provide the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the
|
||||
/// output data tile to be calculated. It determines the workgroup to
|
||||
/// data relationship (or in other words - which data would be
|
||||
/// processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of data
|
||||
/// loading from global memory and performing block-wise matrix
|
||||
/// multiplication. You can think of it as the work done from a single
|
||||
/// workgroup's point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
@@ -580,10 +521,9 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_tensor_view,
|
||||
generate_tuple(d_tensor_view, number<NumDTensor>{}),
|
||||
c_tensor_view);
|
||||
const auto& ds_tensor_view = generate_tuple(d_tensor_view, number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -740,7 +680,9 @@ struct GemmKernel
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const DsGridPointer ds_ptr,
|
||||
@@ -752,10 +694,8 @@ struct GemmKernel
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr,c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
@@ -774,9 +714,12 @@ struct GemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(d_block_window),
|
||||
DstInMemOp>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -795,7 +738,9 @@ struct GemmKernel
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
* @tparam DstInMemOp Destination memory operation (default: set).
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const DsGridPointer ds_ptr,
|
||||
@@ -808,10 +753,8 @@ struct GemmKernel
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
@@ -829,10 +772,12 @@ struct GemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window),
|
||||
decltype(c_block_tile),
|
||||
decltype(d_block_window),
|
||||
DstInMemOp>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs<DsGridPointer>& kargs) const
|
||||
@@ -849,7 +794,6 @@ struct GemmKernel
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
// const DsGridPointer* ds_ptr = reinterpret_cast<const DsGridPointer*>(kargs.ds_ptr);
|
||||
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
@@ -859,9 +803,7 @@ struct GemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
@@ -874,12 +816,27 @@ struct GemmKernel
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
if(kargs.k_batch == 1)
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
@@ -891,6 +848,22 @@ struct GemmKernel
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
TYPED_TEST(TestCkTileMultipleDGemm, Basic)
|
||||
{
|
||||
constexpr int M = 256;
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 512;
|
||||
constexpr int M = 3840;
|
||||
constexpr int N = 4096;
|
||||
constexpr int K = 4096;
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
@@ -120,10 +120,13 @@ class TestCkTileMultipleDGemm : public ::testing::Test
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -152,7 +155,8 @@ class TestCkTileMultipleDGemm : public ::testing::Test
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -180,11 +184,29 @@ class TestCkTileMultipleDGemm : public ::testing::Test
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
|
||||
@@ -175,52 +175,51 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
|
||||
{
|
||||
// this->conv_params.clear();
|
||||
// this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
|
||||
// this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
// this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
// this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
// this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
// this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
// this->Run();
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
|
||||
this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
// this->conv_params.push_back(
|
||||
// {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
// this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0,
|
||||
// 0}}); this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0},
|
||||
// {0, 0}}); this->conv_params.push_back(
|
||||
// {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// this->conv_params.push_back(
|
||||
// {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
// this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1,
|
||||
// 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1},
|
||||
// {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1,
|
||||
// 1}, {1, 1}}); this->conv_params.push_back(
|
||||
// {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
|
||||
{
|
||||
// this->conv_params.clear();
|
||||
// this->conv_params.push_back(
|
||||
// {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// this->conv_params.push_back(
|
||||
// {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// this->Run();
|
||||
}
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->Run();
|
||||
}
|
||||
Reference in New Issue
Block a user