mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,981 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_INFO(std::forward<Args>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType_>
|
||||
struct GroupedConvBwdWeightKernelArgs
|
||||
{
|
||||
|
||||
using ConvToGemmTransformer =
|
||||
TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
|
||||
GroupedConvTraitsType_::ConvSpecialization,
|
||||
GroupedConvTraitsType_::VectorSizeA,
|
||||
GroupedConvTraitsType_::VectorSizeB,
|
||||
GroupedConvTraitsType_::VectorSizeC,
|
||||
GroupedConvTraitsType_::NumGroupsToMerge>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType_::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType_::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType_::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType_::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
// tuple
|
||||
auto grid_descs =
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
|
||||
* NumGroupsPerBatch *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
LogInfo("GemmM: ",
|
||||
GemmM,
|
||||
", GemmN: ",
|
||||
GemmN,
|
||||
", GemmK: ",
|
||||
GemmK,
|
||||
", GemmBatch: ",
|
||||
GemmBatch,
|
||||
", NumGroupsPerBatch: ",
|
||||
NumGroupsPerBatch,
|
||||
", k_batch: ",
|
||||
k_batch);
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
|
||||
using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;
|
||||
|
||||
array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
|
||||
array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
|
||||
array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
|
||||
array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;
|
||||
|
||||
index_t k_batch;
|
||||
index_t GemmM;
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsPerBatch;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* wei_ptr;
|
||||
|
||||
AGridDescKM a_grid_desc_k_m;
|
||||
BGridDescKN b_grid_desc_k_n;
|
||||
CGridDescMN c_grid_desc_m_n;
|
||||
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Backward Weight kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the grouped convolution backward weight kernel template. By
|
||||
/// semantic division of Implicit GEMM algorithm into following parts we achieve
|
||||
/// flexible, versatile and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// "function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam GroupedConvTraitsType_ The type of class providing traits for grouped convolution.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of
|
||||
/// data loading from global memory and performing block-wise
|
||||
/// matrix multiplication. You can think of it as a work done by
|
||||
/// single workgroup point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionBackwardWeightKernel
{
    static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    // GEMM-view layouts as defined by the gemm pipeline.
    using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    // Convolution-view tensor layouts from the traits.
    using InLayout  = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout  = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;

    using GemmDsLayout                  = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    // For backward weight: A = output gradient, B = input, output = weight.
    using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using InDataType  = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
    using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GroupedConvBwdWeightKernelArgsSpecialized =
        GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;

    static constexpr bool IsSplitKSupported = true;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    // This kernel relies on padding in all three GEMM dimensions and on the
    // specific A-col-major / B-row-major / C-row-major combination.
    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    // ExplicitGemm cannot be combined with group merging.
    static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
                      GroupedConvTraitsType_::NumGroupsToMerge == 1,
                  "Not supported!");
|
||||
|
||||
/// @brief Builds a human-readable kernel name from the data types, tensor
///        layouts, pipeline names and the specialization flags.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
    constexpr auto NumGroupsToMerge        = GroupedConvTraitsType_::NumGroupsToMerge;
    // clang-format off
    return concat('_', "grouped_convolution_backward_weight",
                  gemm_prec_str<InDataType, WeiDataType>(),
                  InLayout::name,
                  WeiLayout::name,
                  OutLayout::name,
                  "gemm",
                  GemmPipeline::GetName(),
                  "epilogue",
                  EpiloguePipeline::GetName(),
                  getConvSpecializationString(ConvSpecialization),
                  "MergedGroups",
                  NumGroupsToMerge,
                  "SplitImage",
                  EnableSplitImage,
                  "ExplicitGemm",
                  GroupedConvTraitsType_::ExplicitGemm
                  );
    // clang-format on
}
|
||||
|
||||
/// @brief Returns the same identifier string as GetName().
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
CK_TILE_HOST std::string GetInstanceString() const
|
||||
{
|
||||
static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardWeightKernel>,
|
||||
"Specialization of instance_traits not found. Please check that a "
|
||||
"specialization exists in file "
|
||||
"ck_tile/builder/reflect/"
|
||||
"instance_traits_tile_grouped_convolution_backward_weight.hpp "
|
||||
"for the given template parameters.");
|
||||
return ck_tile::reflect::instance_string<GroupedConvolutionBackwardWeightKernel>();
|
||||
}
|
||||
#endif
|
||||
|
||||
/// @brief Computes the launch grid: x = number of M/N output tiles,
///        y = GemmBatch (merged-group batches), z = split-K partitions.
CK_TILE_HOST static constexpr auto
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
    return dim3(
        TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
}
|
||||
|
||||
/// @brief Returns the workgroup size; halved on wave32 devices.
CK_TILE_HOST static constexpr auto BlockSize()
{
    if(is_wave32())
    {
        return dim3(kBlockSize / 2);
    }
    return dim3(kBlockSize);
}
|
||||
|
||||
/// @brief Converts host-side arguments into device kernel arguments.
///
/// A negative k_batch in the host arguments requests split-K autodeduction;
/// in that case the optimal split-K factor is computed here.
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
{
    LogInfo("MPerBlock: ",
            number<TilePartitioner::MPerBlock>{},
            ", NPerBlock: ",
            number<TilePartitioner::NPerBlock>{},
            ", KPerBlock: ",
            number<TilePartitioner::KPerBlock>{});

    auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);

    using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
                                                              TilePartitioner_,
                                                              GemmPipeline_,
                                                              EpiloguePipeline_>;

    // Negative k_batch value: split-K autodeduction.
    if(kernel_args.k_batch < 0)
    {
        const auto optimal_split_k =
            calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
                kernel_args);
        kernel_args.k_batch = optimal_split_k;
    }

    return kernel_args;
}
|
||||
|
||||
/// @brief LDS bytes required: the larger of the GEMM pipeline's and the
///        epilogue pipeline's requirements (the two phases reuse the space).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
|
||||
|
||||
/// @brief Host-side validation of the kernel arguments.
///
/// Checks, in order: device support for the async pipeline, split-K limits,
/// convolution-specialization preconditions, layout support, vector-access
/// divisibility for each tensor, and group-merge divisibility. Returns false
/// on the first violated condition (with a log message when CK_TILE_LOGGING
/// is enabled).
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
    // Async gemm pipeline is only allowed on gfx950.
    if constexpr(GemmPipeline_::Async)
    {
        if(get_device_name() != "gfx950")
        {
            return false;
        }
    }
    if(kargs.k_batch < 1)
    {
        LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
        return false;
    }

    if constexpr(!std::is_same_v<typename EpiloguePipeline::ODataType, float> &&
                 !std::is_same_v<typename EpiloguePipeline::ODataType, double>)
    {
        // The epilogue performs atomic add related to split-K using the ODataType.
        // If the type is less accurate than float, large split-K values may lead to
        // accuracy issues. Hence, we limit the maximum split-K value to 128 in such cases.
        if(kargs.k_batch > 128)
        {
            LogInfo("For epilogue output data type that is not float/double, we must have "
                    "k_batch <= 128.");
            return false;
        }
    }

    // fp16/bf16 outputs with split-K require an even C vector size.
    if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                  is_any_of<WeiDataType, fp16_t, bf16_t>::value))
    {
        if(kargs.k_batch != 1)
        {
            LogInfo("Conditions not met for K_batch > 1: VectorSizeC must be a multiple of 2 "
                    "for fp16/bf16 when K_batch > 1.",
                    "Now k_batch is ",
                    kargs.k_batch,
                    ", VectorSizeC is ",
                    GroupedConvTraitsType_::VectorSizeC);
            return false;
        }
    }

    // Every split-K partition must get at least one K warp tile of work.
    if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
    {
        LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
                kargs.GemmK,
                ", BlockGemmShape K: ",
                TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
                ", k_batch: ",
                kargs.k_batch);
        return false;
    }

    const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
    const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

    // check ConvSpecialization
    if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
    {
        // check if it's 1x1, stride=1 conv
        for(index_t i = 0; i < NDimSpatial; ++i)
        {
            // Spatial filter lengths start after the G/K/C dimensions.
            const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
            const index_t ConvStride = kargs.conv_filter_strides[i];
            const index_t LeftPad    = kargs.input_left_pads[i];
            const index_t RightPad   = kargs.input_right_pads[i];

            if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
            {
                LogInfo("For Filter1x1Stride1Pad0 specialization, all spatial dimensions must "
                        "be 1, stride must be 1, and padding must be 0. Now for dimension ",
                        i,
                        ": SpatialDim is ",
                        SpatialDim,
                        ", ConvStride is ",
                        ConvStride,
                        ", LeftPad is ",
                        LeftPad,
                        ", RightPad is ",
                        RightPad);
                return false;
            }
        }
    }
    else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
    {
        // check if it's 1x1 conv
        for(index_t i = 0; i < NDimSpatial; ++i)
        {
            const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
            const index_t LeftPad    = kargs.input_left_pads[i];
            const index_t RightPad   = kargs.input_right_pads[i];

            if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
            {
                LogInfo("For Filter1x1Pad0 specialization, all spatial dimensions must be 1 "
                        "and padding must be 0. Now for dimension ",
                        i,
                        ": SpatialDim is ",
                        SpatialDim,
                        ", LeftPad is ",
                        LeftPad,
                        ", RightPad is ",
                        RightPad);
                return false;
            }
        }
    }
    else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
    {
        if(ConvC != 1)
        {
            LogInfo("For Filter3x3 specialization, ConvC must be 1. Now ConvC is ", ConvC);
            return false;
        }
        for(index_t i = 0; i < NDimSpatial; ++i)
        {
            const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

            if(filter_spatial_dim != I3)
            {
                LogInfo("For Filter3x3 specialization, all spatial dimensions of the filter "
                        "must be 3. Now for dimension ",
                        i,
                        ", filter_spatial_dim is ",
                        filter_spatial_dim);
                return false;
            }
        }
    }

    if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
                 ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0)
    {
        LogInfo("ExplicitGemm is only supported for Filter1x1Stride1Pad0 specialization.");
        return false;
    }

    namespace ctc = tensor_layout::convolution;

    if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                 std::is_same_v<InLayout, ctc::NDHWGC>)
    {
        // Check access per C
        if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
        {
            LogInfo("Conv C is not a multiple of vector load size for input! ConvC: ",
                    ConvC,
                    ", VectorSizeB: ",
                    GroupedConvTraitsType_::VectorSizeB);
            return false;
        }
    }
    else
    {
        LogInfo("Not supported input layout! Now InLayout is ", InLayout::name);
        return false;
    }

    if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                 std::is_same_v<WeiLayout, ctc::GKYXC> ||
                 std::is_same_v<WeiLayout, ctc::GKZYXC>)
    {
        if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
        {
            LogInfo("Conv C is not a multiple of vector load size for weight! ConvC: ",
                    ConvC,
                    ", VectorSizeC: ",
                    GroupedConvTraitsType_::VectorSizeC);
            return false;
        }
    }
    else
    {
        LogInfo("Not supported weight layout! Now WeiLayout is ", WeiLayout::name);
        return false;
    }

    if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                 std::is_same_v<OutLayout, ctc::NHWGK> ||
                 std::is_same_v<OutLayout, ctc::NDHWGK>)
    {
        if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
        {
            LogInfo("Conv K is not a multiple of vector load size for output! ConvK: ",
                    ConvK,
                    ", VectorSizeA: ",
                    GroupedConvTraitsType_::VectorSizeA);
            return false;
        }
    }
    else
    {
        LogInfo("Not supported output layout! Now OutLayout is ", OutLayout::name);
        return false;
    }

    if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
    {
        const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
        if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
        {
            LogInfo("Number of groups must be divisible by NumGroupsToMerge! ConvG: ",
                    ConvG,
                    ", NumGroupsToMerge: ",
                    GroupedConvTraitsType_::NumGroupsToMerge);
            return false;
        }
    }

    return true;
}
|
||||
|
||||
/// @brief Creates the output (C, i.e. weight-gradient) tile window for this block.
///
/// @tparam DstInMemOp Memory operation applied when writing to global memory.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto
MakeCBlockWindow(WeiDataType* c_ptr,
                 const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                 const index_t block_idx_m,
                 const index_t block_idx_n)
{
    const auto& c_tensor_view =
        make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);

    // Pad both M and N so edge blocks with partial tiles are handled.
    const auto& c_pad_view = pad_tensor_view(
        c_tensor_view,
        make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
        sequence<true, true>{});

    return make_tile_window(
        c_pad_view,
        make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
        {block_idx_m, block_idx_n});
}
|
||||
|
||||
/// @brief Creates the tile windows over the D tensors for this block.
///
/// Each D tensor must share the output layout and the weight data type
/// (enforced via static_assert). The views use the same M/N grid descriptor
/// and padding as the C output window.
CK_TILE_DEVICE static auto
MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                  const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                  const index_t block_idx_m,
                  const index_t block_idx_n)
{
    const auto& ds_tensor_view = generate_tuple(
        [&](auto i) {
            static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                          "Not supported!");
            static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                          "Not supported!");
            static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
                          "Not supported!");

            // Fix: ds_ptr stores const void*, so static_cast<WeiDataType*> would
            // cast away const and is ill-formed; D tensors are read-only here,
            // keep the pointee const.
            return make_tensor_view<address_space_enum::global>(
                static_cast<const WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
        },
        number<NumDTensor>{});

    // Pad both M and N so edge blocks with partial tiles are handled.
    const auto& ds_pad_view = generate_tuple(
        [&](auto i) {
            return pad_tensor_view(ds_tensor_view[i],
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<true, true>{});
        },
        number<NumDTensor>{});

    return generate_tuple(
        [&](auto i) {
            return make_tile_window(ds_pad_view[i],
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::NPerBlock>{}),
                                    {block_idx_m, block_idx_n});
        },
        number<NumDTensor>{});
}
|
||||
|
||||
    /// @brief Creates the B (input-image) tile window for this workgroup.
    ///
    /// @param b_ptr       Input B pointer (global memory).
    /// @param kargs       Grouped Convolution Backward Weight kernel arguments.
    /// @param block_idx_n N dimension element offset of this workgroup's tile.
    /// @param block_idx_k K dimension element offset of this workgroup's split-K slice.
    ///
    /// @return Tile window of shape KPerBlock x NPerBlock over the padded B view.
    CK_TILE_DEVICE static auto
    MakeBBlockWindow(const InDataType* b_ptr,
                     const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                     const index_t block_idx_n,
                     const index_t block_idx_k)
    {
        static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& b_tensor_view =
            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_k_n);

        // Pad K by KPerBlock * k_batch so the padded extent covers the full
        // (split) K range; the window itself advances one KPerBlock slice.
        const auto& b_pad_view =
            pad_tensor_view(b_tensor_view,
                            make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
                                       number<TilePartitioner::NPerBlock>{}),
                            sequence<true, true>{});

        return make_tile_window(
            b_pad_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_k, block_idx_n});
    }
|
||||
|
||||
    /// @brief Creates the A (output-gradient) tile window for this workgroup.
    ///
    /// @param a_ptr       Input A pointer (global memory).
    /// @param kargs       Grouped Convolution Backward Weight kernel arguments.
    /// @param block_idx_m M dimension element offset of this workgroup's tile.
    /// @param block_idx_k K dimension element offset of this workgroup's split-K slice.
    ///
    /// @return Tile window of shape KPerBlock x MPerBlock over the padded A view.
    CK_TILE_DEVICE static auto
    MakeABlockWindow(const OutDataType* a_ptr,
                     const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                     const index_t block_idx_m,
                     const index_t block_idx_k)
    {
        static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
        const auto& a_tensor_view =
            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_k_m);

        // Pad K by KPerBlock * k_batch so the padded extent covers the full
        // (split) K range; mirrors MakeBBlockWindow.
        const auto& a_pad_view =
            pad_tensor_view(a_tensor_view,
                            make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
                                       number<TilePartitioner::MPerBlock>{}),
                            sequence<true, true>{});

        return make_tile_window(
            a_pad_view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            {block_idx_k, block_idx_m});
    }
|
||||
|
||||
    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @param a_ptr input A (output-gradient) pointer
     * @param b_ptr input B (input-image) pointer
     * @param ds_ptr pointers of the D tensors fused in the epilogue (NumDTensor entries)
     * @param c_ptr output C (weight-gradient) pointer
     * @param smem_ptr_0 The start memory pointer of the shared memory block.
     * @param kargs Grouped Convolution Backward Weight kernel arguments
     * @param num_loop number of K-loop iterations of this workgroup's split-K slice
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     * @param block_idx_k The K dimension offset of this workgroup's split-K slice.
     *
     */
    CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
                                       const InDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       WeiDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvBwdWeightKernelArgsSpecialized& kargs,
                                       const index_t num_loop,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
                                       const index_t block_idx_k)
    {
        // Create block windows using helper methods
        const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k);
        const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k);
        const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

        // Run GEMM cooperatively by whole workgroup.
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline with k_batch dispatching:
        // - k_batch == 1: plain stores.
        // - k_batch > 1 : partial results from each split are accumulated with
        //   global atomic adds.
        if(kargs.k_batch == 1)
        {
            auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
                c_ptr, kargs, block_idx_m, block_idx_n);

            EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
        }
        else
        {
            // Odd C vector sizes with fp16/bf16 cannot use packed atomic adds, so
            // that combination is compiled out here.
            // NOTE(review): when the condition fails nothing is written for this
            // split — presumably such configs are rejected host-side before launch;
            // confirm against the IsSupported* check.
            if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           is_any_of<WeiDataType, fp16_t, bf16_t>::value))
            {
                auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
                    c_ptr, kargs, block_idx_m, block_idx_n);

                EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0);
            }
        }
    }
|
||||
|
||||
    /// @brief Dispatches the problem through an explicit batched GEMM kernel
    ///        (one GEMM batch per element of GemmBatch) instead of the
    ///        implicit-GEMM tile path.
    CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
    {
        // Fused D tensors are not supported on the explicit-GEMM path.
        static_assert(NumDTensor == 0, "Not supported!");
        using ExplicitBatchedGemmKernel =
            BatchedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
        // A = output gradient, B = input image, C = weight gradient.
        // NOTE(review): the nested brace-init follows BatchedGemmKernelArgs'
        // aggregate member order (ptrs, M/N/K, strides, k_batch, batch strides,
        // batch count) — verify against the BatchedGemmKernel definition.
        const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
            {{kargs.out_ptr},
             {kargs.in_ptr},
             {},
             kargs.wei_ptr,
             kargs.GemmM,
             kargs.GemmN,
             kargs.GemmK,
             {kargs.GemmM * kargs.GemmBatch},
             {kargs.GemmN * kargs.GemmBatch},
             {},
             kargs.GemmN,
             kargs.k_batch},
            kargs.GemmM,
            kargs.GemmN,
            kargs.GemmM * kargs.GemmN,
            kargs.GemmBatch};
        ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
    }
|
||||
|
||||
    /// @brief Kernel entry point.
    ///
    /// Maps the launch grid onto the weight-gradient GEMM:
    /// blockIdx.x selects the (M, N) output tile, blockIdx.y the convolution
    /// group, blockIdx.z the split-K slice.
    CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
    {
        if constexpr(GroupedConvTraitsType_::ExplicitGemm)
        {
            CallExplicitGemm(kargs);
        }
        else
        {
            // amd_wave_read_first_lane broadcasts lane 0's value, keeping these
            // block-uniform indices in scalar registers.
            const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
            const auto [iM, iN] =
                TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
            const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
            const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

            // Split-K: each z-block works on `num_loop` K-tiles starting at i_k.
            const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
            const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil(
                kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
            const index_t i_k =
                amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);

            // Group (y-block) offsets into each tensor.
            const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
            const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
            const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
            const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

            // conv_bwd_weight = Out * In = Weight
            // i.e. A = output gradient, B = input image, C = weight gradient.
            const OutDataType* a_ptr =
                static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
            const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
            WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;

            __shared__ char smem_ptr[GetSmemSize()];

            // Disable Async for other archs than gfx950
            if constexpr(GemmPipeline_::Async)
            {
#if defined(__gfx950__)
                RunGemm(
                    a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
#endif
            }
            else
            {
                RunGemm(
                    a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
            }
        }
    }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,208 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// UniversalGemm policy for grouped convolution: A/B staged through LDS (global
// -> LDS -> register), C accumulated in registers.
struct GroupedConvUniversalPipelineAgBgCrPolicy
    : public UniversalGemmBasePolicy<GroupedConvUniversalPipelineAgBgCrPolicy>
{

    /// @brief Create the LDS block descriptor for the A tensor.
    ///
    /// @tparam Problem           Gemm pipeline problem.
    /// @tparam OverrideADataType Element type used for LDS sizing; defaults to
    ///                           the problem's A data type.
    /// @return A tensor LDS block descriptor (XOR-permuted to reduce bank
    ///         conflicts on the non-transposed-load path).
    template <typename Problem,
              typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
    CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        using ADataType             = OverrideADataType;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack     = GetSmemPackA<Problem>();

        if constexpr(is_a_load_tr<Problem>)
        {
            // Transposed-load path: plain (K, M) row-major layout.
            // TODO: better lds descriptor for performance
            constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
                make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
                make_tuple(number<MPerBlock>{}, number<1>{}),
                number<MPerBlock>{},
                number<1>{});
            return a_lds_block_desc_0;
        }
        else
        {
            // Replicate the block along M (MLdsLayer copies) so one K-row spans
            // a full set of LDS banks before the XOR swizzle is applied.
            constexpr auto DataTypeSize    = sizeof(ADataType);
            constexpr uint64_t MinLdsLayer = 1ULL;
            constexpr auto MLdsLayer =
                max(MinLdsLayer,
                    get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);

            // RowMul widens the XOR period for 64-bank LDS hardware.
            constexpr index_t NBanks = get_n_lds_banks();
            static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
            constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;

            // Raw (K0*MLdsLayer, M/MLdsLayer, KPack) layout.
            constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
                           number<MPerBlock / MLdsLayer>{},
                           number<KPack>{}),
                make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
                number<KPack>{},
                number<1>{});

            // XOR-swizzle rows against columns to spread accesses over banks.
            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc_0,
                make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
                                                         number<KPerBlock / KPack * MLdsLayer>{})),
                           make_pass_through_transform(number<KPack>{})),
                make_tuple(sequence<1, 0>{}, sequence<2>{}),
                make_tuple(sequence<1, 0>{}, sequence<2>{}));

            // Split the fused (MLdsLayer * K0) axis back apart.
            constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(
                               make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
                           make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(number<KPack>{})),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

            // Merge back to the logical 2-D (M, K) view the pipeline consumes.
            constexpr auto a_lds_block_desc = transform_tensor_descriptor(
                a_lds_block_desc_xk0_mnldslayer_mn_xk1,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
                make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return a_lds_block_desc;
        }
    }

    /**
     * @brief Create LDS block descriptor for B tensor.
     *
     * Mirrors MakeALdsBlockDescriptor with N in place of M. When the
     * broadcast-before-LDS-write policy is active, sizing uses the A data type
     * (the data written to LDS is the pre-cast A-typed form).
     *
     * @tparam Problem Gemm pipeline problem.
     * @return B tensor LDS block descriptor.
     */
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
        using BDataType                            = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
                                                                        typename Problem::ADataType,
                                                                        typename Problem::BDataType>;

        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(is_b_load_tr<Problem>)
        {
            // Transposed-load path: plain (K, N) row-major layout.
            // TODO: better lds descriptor for performance
            constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
                make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
                make_tuple(number<NPerBlock>{}, number<1>{}),
                number<NPerBlock>{},
                number<1>{});
            return b_lds_block_desc_0;
        }
        else
        {
            constexpr index_t KPack        = GetSmemPackB<Problem>();
            constexpr auto BK0             = number<KPerBlock / KPack>{};
            constexpr auto DataTypeSize    = sizeof(BDataType);
            constexpr uint64_t MinLdsLayer = 1ULL;
            constexpr auto NLdsLayer =
                max(MinLdsLayer,
                    get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);

            constexpr index_t NBanks = get_n_lds_banks();
            static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
            constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;

            // Raw (K0*NLdsLayer, N/NLdsLayer, KPack) layout.
            constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(
                    BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
                make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
                number<KPack>{},
                number<1>{});

            // XOR-swizzle to avoid bank conflicts, as on the A side.
            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc_0,
                make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
                                                         BK0 * number<NLdsLayer>{})),
                           make_pass_through_transform(number<KPack>{})),
                make_tuple(sequence<1, 0>{}, sequence<2>{}),
                make_tuple(sequence<1, 0>{}, sequence<2>{}));

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
                           make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(number<KPack>{})),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

            // Merge back to the logical 2-D (N, K) view the pipeline consumes.
            constexpr auto b_lds_block_desc = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
                           make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
                make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            return b_lds_block_desc;
        }
    }

    /// @brief Builds the block-level GEMM object used by the pipeline.
    ///
    /// Selects a warp GEMM for the problem's warp tile, deriving the number of
    /// accesses per instruction from the transposed-load vector width, and
    /// substitutes compute-capable element types for packed 4-bit inputs.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile   = typename Problem::BlockGemmShape::WarpTile;

        // How many ds_read_tr accesses are needed to fill one thread's
        // share of the warp tile.
        constexpr index_t vector_size =
            DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
        constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
        constexpr auto wg_attr_num_access =
            !(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
            : vector_size == thread_elements                  ? WGAttrNumAccessEnum::Single
            : vector_size * 2 == thread_elements              ? WGAttrNumAccessEnum::Double
            : vector_size * 4 == thread_elements              ? WGAttrNumAccessEnum::Quad
                                                              : WGAttrNumAccessEnum::Invalid;

        // pk_int4/pk_fp4 (and narrower B) are widened to the partner operand's
        // type for the warp GEMM dispatch.
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using ATypeToUse =
            std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
        using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
                                                  std::is_same_v<BDataType, pk_fp4_t> ||
                                                  sizeof(BDataType) < sizeof(ADataType),
                                              ADataType,
                                              BDataType>;

        using WarpGemm = WarpGemmDispatcher<ATypeToUse,
                                            BTypeToUse,
                                            typename Problem::CDataType,
                                            WarpTile::at(I0),
                                            WarpTile::at(I1),
                                            WarpTile::at(I2),
                                            Problem::TransposeC,
                                            false,
                                            Problem::UseStructuredSparsity,
                                            wg_attr_num_access>;

        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
                                                                      BTypeToUse,
                                                                      typename Problem::CDataType,
                                                                      BlockWarps,
                                                                      WarpGemm>;
        return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct ConvolutionSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Pad0,
|
||||
Filter1x1Stride1Pad0,
|
||||
Filter3x3,
|
||||
};
|
||||
|
||||
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ConvolutionSpecialization::Default: return "Default";
|
||||
case ConvolutionSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionSpecialization::Filter3x3: return "Filter3x3";
|
||||
default: return "Unrecognized specialization!";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Direction of the grouped convolution problem; selects the GEMM layout set
/// and which tensor pointer is mutable (output for forward, input for
/// backward-data, weight for backward-weight — see the HostArgs aliases below).
enum class GroupedConvDirection
{
    FORWARD,
    BACKWARD_DATA,
    BACKWARD_WEIGHT
};
|
||||
|
||||
/// @brief The Grouped Conv kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
|
||||
InPtr in_ptr_,
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_,
|
||||
CDElementwise elfunc_ = CDElementwise{})
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_),
|
||||
elfunc(elfunc_)
|
||||
{
|
||||
}
|
||||
|
||||
InPtr in_ptr;
|
||||
WeiPtr wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
const CDElementwise elfunc;
|
||||
};
|
||||
|
||||
using PassThrough = ck_tile::element_wise::PassThrough;
|
||||
|
||||
template <typename CDElementwise = PassThrough>
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
|
||||
using GroupedConvBwdWeightHostArgs =
|
||||
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
|
||||
using GroupedConvBwdDataHostArgs =
|
||||
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
bool EnableSplitImage_ = false,
|
||||
bool ExplicitGemm_ = false>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
static constexpr auto generate_implicit_gemm_layout()
|
||||
{
|
||||
return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
|
||||
number<DsLayout_::size()>{});
|
||||
}
|
||||
|
||||
public:
|
||||
// Fixed values for Implicit GEMM
|
||||
struct FixedGemmParams
|
||||
{
|
||||
static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool FixedVectorSize = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Persistent = false;
|
||||
using ELayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
};
|
||||
// Compile time parameters
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr bool ExplicitGemm = ExplicitGemm_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
|
||||
// Forward Gemm Layouts
|
||||
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Data Gemm Layouts
|
||||
using AsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BsLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdData = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
// Backward Weight Gemm Layouts
|
||||
using AsLayoutBwdWeight = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
template <GroupedConvDirection Direction>
|
||||
struct GemmLayouts
|
||||
{
|
||||
static_assert(false, "Unsupported direction.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmLayouts<GroupedConvDirection::FORWARD>
|
||||
{
|
||||
using AsLayout = AsLayoutFwd;
|
||||
using BsLayout = BsLayoutFwd;
|
||||
using CLayout = CLayoutFwd;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmLayouts<GroupedConvDirection::BACKWARD_DATA>
|
||||
{
|
||||
using AsLayout = AsLayoutBwdData;
|
||||
using BsLayout = BsLayoutBwdData;
|
||||
using CLayout = CLayoutBwdData;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmLayouts<GroupedConvDirection::BACKWARD_WEIGHT>
|
||||
{
|
||||
using AsLayout = AsLayoutBwdWeight;
|
||||
using BsLayout = BsLayoutBwdWeight;
|
||||
using CLayout = CLayoutBwdWeight;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdData = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdData,
|
||||
BsLayoutBwdData,
|
||||
CLayoutBwdData,
|
||||
NumWaveGroups>;
|
||||
template <ck_tile::index_t NumWaveGroups = 1>
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
AsLayoutBwdWeight,
|
||||
BsLayoutBwdWeight,
|
||||
CLayoutBwdWeight,
|
||||
NumWaveGroups>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
static constexpr ck_tile::index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr ck_tile::index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
/// @brief Helper struct for split-image piece information
///
/// @par Overview
///         Stores metadata for a single spatial piece in split-image convolution.
///         Used to track block ranges and spatial coordinates for each piece.
struct SplitImagePieceInfo
{
    ck_tile::index_t block_start, block_end;    ///< GPU block range of this piece, half-open: [block_start, block_end)
    ck_tile::index_t d_start, h_start, w_start; ///< Spatial start coordinates (output space)
    ck_tile::index_t d_size, h_size, w_size;    ///< Spatial dimensions of this piece
};
|
||||
|
||||
/// @brief Calculate piece information for split-image convolution
|
||||
///
|
||||
/// @par Overview
|
||||
/// Computes spatial coordinates, dimensions, and GPU block range for a single
|
||||
/// piece in split-image convolution. Handles edge pieces that may have different
|
||||
/// sizes due to non-uniform division.
|
||||
///
|
||||
/// @tparam TilePartitioner Type providing MPerBlock and NPerBlock constants
|
||||
///
|
||||
/// @param piece_idx Index of the piece to calculate (0-based)
|
||||
/// @param num_d_pieces Number of pieces in D dimension
|
||||
/// @param num_h_pieces Number of pieces in H dimension
|
||||
/// @param num_w_pieces Number of pieces in W dimension
|
||||
/// @param base_piece_d Base size of each D piece (may differ for last piece)
|
||||
/// @param base_piece_h Base size of each H piece (may differ for last piece)
|
||||
/// @param base_piece_w Base size of each W piece (may differ for last piece)
|
||||
/// @param total_d Total D dimension size (output space)
|
||||
/// @param total_h Total H dimension size (output space)
|
||||
/// @param total_w Total W dimension size (output space)
|
||||
/// @param N Batch size
|
||||
/// @param K Output channels
|
||||
/// @param total_blocks Accumulated block count from previous pieces
|
||||
///
|
||||
/// @return SplitImagePieceInfo containing all metadata for this piece
|
||||
template <typename TilePartitioner>
|
||||
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx,
|
||||
ck_tile::index_t num_d_pieces,
|
||||
ck_tile::index_t num_h_pieces,
|
||||
ck_tile::index_t num_w_pieces,
|
||||
ck_tile::index_t base_piece_d,
|
||||
ck_tile::index_t base_piece_h,
|
||||
ck_tile::index_t base_piece_w,
|
||||
ck_tile::index_t total_d,
|
||||
ck_tile::index_t total_h,
|
||||
ck_tile::index_t total_w,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t total_blocks)
|
||||
{
|
||||
// Unflatten piece index into 3D coordinates (W-major, then H, then D)
|
||||
const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
|
||||
const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
|
||||
const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
|
||||
|
||||
// Calculate spatial start positions
|
||||
const ck_tile::index_t w_start = w_idx * base_piece_w;
|
||||
const ck_tile::index_t h_start = h_idx * base_piece_h;
|
||||
const ck_tile::index_t d_start = d_idx * base_piece_d;
|
||||
|
||||
// Calculate piece sizes (last piece may be larger to cover remainder)
|
||||
const ck_tile::index_t w_size =
|
||||
(w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
|
||||
const ck_tile::index_t h_size =
|
||||
(h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
|
||||
const ck_tile::index_t d_size =
|
||||
(d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
|
||||
|
||||
// Calculate GEMM dimensions for this piece
|
||||
const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
|
||||
const ck_tile::index_t piece_gemm_n = K;
|
||||
|
||||
// Calculate GPU grid size for this piece
|
||||
const ck_tile::index_t piece_grid =
|
||||
((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
|
||||
((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
|
||||
|
||||
return {
|
||||
total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>

#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Queries the maximum number of active workgroups per CU for a kernel.
///
/// Asks the HIP runtime (hipOccupancyMaxActiveBlocksPerMultiprocessor) about
/// the same kentry<> wrapper that is used at launch time, so the result
/// reflects the kernel's real register/LDS footprint. Assumes no dynamic
/// shared memory.
///
/// @tparam BlockSize  Workgroup size the kernel is launched with.
/// @tparam KernelArgs Kernel argument type of the entry point.
/// @tparam KernelImpl Kernel functor type.
/// @return Max active workgroups per CU reported by the runtime.
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
CK_TILE_HOST index_t get_max_occupancy_for_kernel()
{
    constexpr int dynamic_smem_size = 0;
    constexpr int min_blocks_per_cu = 1;

    // Must match the entry point used at launch so occupancy is accurate.
    const auto kernel_ptr = kentry<min_blocks_per_cu, KernelImpl, KernelArgs>;

    int max_occupancy = 0;
    hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_occupancy, kernel_ptr, BlockSize, dynamic_smem_size));

    return static_cast<index_t>(max_occupancy);
}
|
||||
|
||||
/// @brief Picks a split-K factor that fills the GPU when the plain output grid
///        under-occupies it.
///
/// @param max_occupancy Max active workgroups per CU for the kernel.
/// @param grid_size     Workgroup count of the un-split output grid.
/// @return floor((max_occupancy * num_cus) / grid_size) when that is > 1,
///         otherwise 1.
CK_TILE_HOST index_t get_best_occupancy_k_batch_value(index_t max_occupancy, index_t grid_size)
{
    // CU count cannot change within a process, so query it once.
    static const index_t num_cus = get_num_cus();
    const index_t max_capacity   = max_occupancy * num_cus;

    // Integer division is the floor we want directly — no std::floor/double
    // round-trip — and guarding grid_size avoids division by zero for a
    // degenerate (empty) grid.
    index_t k_batch = 1;
    if(grid_size > 0)
    {
        const index_t optimal_split = max_capacity / grid_size;
        if(optimal_split > 1)
        {
            k_batch = optimal_split;
        }
    }

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
                  << max_occupancy << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
        std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
    }
    return k_batch;
}
|
||||
|
||||
template <index_t BlockSize, typename KernelArgs, typename KernelImpl>
|
||||
struct ActiveWorkgroupsPerCU
|
||||
{
|
||||
CK_TILE_HOST ActiveWorkgroupsPerCU()
|
||||
{
|
||||
max_occupancy_ = get_max_occupancy_for_kernel<BlockSize, KernelArgs, KernelImpl>();
|
||||
}
|
||||
index_t max_occupancy_{1};
|
||||
};
|
||||
|
||||
/// @brief Deduces a split-K factor for a GEMM from the kernel's occupancy.
///
/// The occupancy query is cached in a function-local static, so the driver
/// call happens once per (BlockSize, KernelArgs, KernelImpl) instantiation.
///
/// @param kargs Kernel arguments providing GemmM/GemmN/GemmK/GemmBatch.
/// @return k_batch in [1, GemmK] that best fills the device.
template <index_t BlockSize, typename KernelImpl, typename TilePartitioner, typename KernelArgs>
CK_TILE_HOST index_t calculate_optimal_k_batch(const KernelArgs& kargs)
{
    static ActiveWorkgroupsPerCU<BlockSize, KernelArgs, KernelImpl> active_workgroups_per_cu;

    const auto grid_size = TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN) * kargs.GemmBatch;
    auto optimal_k_batch =
        get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size);

    // Clamp so every split has at least one K element.
    // NOTE(review): clamping to GemmK allows splits finer than one K-tile;
    // confirm whether ceil(GemmK / KPerBlock) was the intended upper bound.
    const auto max_allowed_k_batch = kargs.GemmK;
    optimal_k_batch                = std::min(optimal_k_batch, max_allowed_k_batch);

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << optimal_k_batch << std::endl;
    }

    return optimal_k_batch;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user