mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Add grouped conv bwd weight dl instances and new layout (#897)
* Add grouped conv bwd weight dl instances and new layout
* Add M and N padding
* Remove todo comment
* Enable grouped conv fwd dl k,c=1 generic instance
* Comment fixes
[ROCm/composable_kernel commit: 475188ca2e]
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
@@ -27,28 +29,59 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
using NDimSpatial = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
ck::index_t split_k{2};
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
|
||||
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
|
||||
{
|
||||
// K or C are odd is supported only by DL kernel (only applies to fp16)
|
||||
// DL kernel is only supported for split_k=1
|
||||
if constexpr(std::is_same_v<InDataType, ck::half_t>)
|
||||
{
|
||||
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// 1d NWGC is only supported by DL kernel
|
||||
// DL kernel is only supported for split_k=1
|
||||
if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
|
||||
{
|
||||
if(split_k != 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
|
||||
for(auto& param : conv_params)
|
||||
for(auto split_k : split_ks)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
if(!skip_case(param, split_k))
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -69,12 +102,13 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes1d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>>;
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>,
|
||||
std::tuple<float, float, float, NWGC, GKXC, NWGK, ck::Number<1>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NWGC, GKXC, NWGK, ck::Number<1>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NWGC, GKXC, NWGK, ck::Number<1>>>;
|
||||
using KernelTypes2d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
|
||||
|
||||
Reference in New Issue
Block a user