[CK_TILE] Switch into universal gemms for conv bwds (#2981)

* switch into universal gemms for conv bwds

* some fixes and support universal gemm in conv fwd

* add reviewer comments
This commit is contained in:
jakpiase
2025-10-14 16:09:16 +02:00
committed by GitHub
parent 589e242eda
commit 6deaaa92cc
19 changed files with 1043 additions and 550 deletions

View File

@@ -44,13 +44,13 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -145,15 +145,15 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -161,13 +161,13 @@ struct GroupedConvBwdDataKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -262,17 +262,17 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -281,17 +281,17 @@ struct GroupedConvBwdDataKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -387,8 +387,8 @@ struct GroupedConvBwdDataKernelArgs
static constexpr index_t MaxGroupedGemmGroupsNum = 128;
using ABCGridDescs = remove_cvref_t<
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
using ABCGridDescs = remove_cvref_t<decltype(
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
@@ -471,10 +471,6 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel
{
// Todo: Enable Vector Load Size > 1
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
GroupedConvTraitsType_::VectorSizeB == 1);
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
@@ -517,12 +513,7 @@ struct GroupedConvolutionBackwardDataKernel
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
// TODO: Change to and enable vector load
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
// "Not supported A GEMM layout!");
// static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
// "Not supported B GEMM layout!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
"Not supported C GEMM layout!");
@@ -742,8 +733,8 @@ struct GroupedConvolutionBackwardDataKernel
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
@@ -788,9 +779,9 @@ struct GroupedConvolutionBackwardDataKernel
const auto& b_block_window = [&]() {
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, i_k});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_k, i_n});
}();
const auto ds_block_window = generate_tuple(

View File

@@ -40,13 +40,13 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -80,8 +80,8 @@ struct GroupedConvBwdWeightKernelArgs
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_m_k = grid_descs.at(number<0>{});
b_grid_desc_n_k = grid_descs.at(number<1>{});
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NWGK
@@ -92,9 +92,9 @@ struct GroupedConvBwdWeightKernelArgs
1,
std::multiplies<index_t>());
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
}
@@ -109,15 +109,15 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -125,13 +125,13 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -156,8 +156,8 @@ struct GroupedConvBwdWeightKernelArgs
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_m_k = grid_descs.at(number<0>{});
b_grid_desc_n_k = grid_descs.at(number<1>{});
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NHWGK
@@ -168,9 +168,9 @@ struct GroupedConvBwdWeightKernelArgs
1,
std::multiplies<index_t>());
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
}
@@ -185,17 +185,17 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -204,17 +204,17 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -239,8 +239,8 @@ struct GroupedConvBwdWeightKernelArgs
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
GroupedConvTraitsType_::NDimSpatial>();
a_grid_desc_m_k = grid_descs.at(number<0>{});
b_grid_desc_n_k = grid_descs.at(number<1>{});
a_grid_desc_k_m = grid_descs.at(number<0>{});
b_grid_desc_k_n = grid_descs.at(number<1>{});
c_grid_desc_m_n = grid_descs.at(number<2>{});
group_stride_a = args.K_; // A: Out NDHWGK
@@ -251,17 +251,17 @@ struct GroupedConvBwdWeightKernelArgs
1,
std::multiplies<index_t>());
GemmM = a_grid_desc_m_k.get_length(number<0>{});
GemmN = b_grid_desc_n_k.get_length(number<0>{});
GemmK = a_grid_desc_m_k.get_length(number<1>{});
GemmM = a_grid_desc_k_m.get_length(number<1>{});
GemmN = b_grid_desc_k_n.get_length(number<1>{});
GemmK = a_grid_desc_k_m.get_length(number<0>{});
GemmBatch = args.G_;
}
using ABCGridDescs = remove_cvref_t<
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
using ABCGridDescs = remove_cvref_t<decltype(
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
static constexpr index_t NonSpatialDims = 3;
@@ -285,8 +285,8 @@ struct GroupedConvBwdWeightKernelArgs
std::array<const void*, NumDTensor> ds_ptr;
void* wei_ptr;
AGridDescMK a_grid_desc_m_k;
BGridDescNK b_grid_desc_n_k;
AGridDescKM a_grid_desc_k_m;
BGridDescKN b_grid_desc_k_n;
CGridDescMN c_grid_desc_m_n;
long_index_t group_stride_a;
@@ -338,10 +338,6 @@ template <typename GroupedConvTraitsType_,
typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
// Todo: Enable Vector Load Size > 1
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
GroupedConvTraitsType_::VectorSizeB == 1);
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
@@ -380,12 +376,8 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
// TODO: Change to and enable vector load
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not
// supported!"); static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not
// supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -592,12 +584,12 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr,
kargs.a_grid_desc_m_k); // A: out
kargs.a_grid_desc_k_m); // A: out
}();
const auto& b_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(b_ptr,
kargs.b_grid_desc_n_k); // B: in
kargs.b_grid_desc_k_n); // B: in
}();
const auto& c_tensor_view = [&]() {
@@ -628,16 +620,16 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{} * k_batch),
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
}();
const auto& b_pad_view = [&]() {
const auto& b_tensor_view = views.at(I1);
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{} * k_batch),
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
number<TilePartitioner::NPerBlock>{}),
sequence<true, true>{});
}();
@@ -675,16 +667,16 @@ struct GroupedConvolutionBackwardWeightKernel
const auto& a_block_window = [&]() {
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, i_k});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_k, i_m});
}();
const auto& b_block_window = [&]() {
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, i_k});
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_k, i_n});
}();
const auto ds_block_window = generate_tuple(

View File

@@ -41,13 +41,13 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -124,15 +124,15 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -140,13 +140,13 @@ struct GroupedConvFwdKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -216,17 +216,17 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -235,17 +235,17 @@ struct GroupedConvFwdKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -306,15 +306,15 @@ struct GroupedConvFwdKernelArgs
args.output_spatial_lengths_[2];
}
using AGridDescMK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
using BGridDescNK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
using CGridDescMN = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
using AGridDescMK = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
using BGridDescNK = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
using CGridDescMN = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
static constexpr index_t NonSpatialDims = 3;
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;

View File

@@ -82,20 +82,14 @@ struct GroupedConvTraits
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to and enable vector load
// ck_tile::tensor_layout::gemm::RowMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to and enable vector load
// ck_tile::tensor_layout::gemm::ColumnMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;

View File

@@ -502,7 +502,7 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
number<VectorSizeC>{},
I1);
}
@@ -512,7 +512,7 @@ struct TransformConvBwdDataToGemm
// GKYXC
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
number<VectorSizeB>{},
I1);
}
@@ -547,7 +547,7 @@ struct TransformConvBwdDataToGemm
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
number<VectorSizeC>{},
I1);
}
@@ -558,7 +558,7 @@ struct TransformConvBwdDataToGemm
return make_naive_tensor_descriptor(
make_tuple(K_, Z_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
number<VectorSizeB>{},
I1);
}
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
@@ -642,7 +642,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
@@ -797,7 +797,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 2, 0>{}, sequence<3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -999,7 +999,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(

View File

@@ -421,7 +421,6 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
@@ -463,9 +462,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
@@ -480,7 +478,7 @@ struct TransformConvBwdWeightToGemm
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
@@ -506,9 +504,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
@@ -577,7 +574,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
@@ -614,7 +611,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
@@ -657,7 +654,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}