Grouped conv bwd wei explicit GEMM for odd C/K (#2306)

Author: Bartłomiej Kocot
Date: 2025-06-10 11:17:12 +02:00
Committed by: GitHub
Parent: 9fcf21a4ec
Commit: 7a83f1d510
20 changed files with 557 additions and 434 deletions
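In short: when sizeof(WeiDataType) is not a multiple of 4 and the GEMM's CShuffle store vector length is odd (the odd C/K case from the title), the explicit GEMM can no longer write the weight gradient out directly. The diff below therefore adds a two-stage fallback: stage one runs the explicit GEMM into a float workspace, and stage two casts that workspace into the weight tensor with an elementwise kernel. A minimal standalone restatement of the dispatch condition (hypothetical helper, not the CK API):

#include <cstddef>
#include <cstdio>

// Hypothetical restatement of the IsTwoStageNeeded condition from the diff
// below: the two-stage path kicks in when the weight element is not
// dword-sized and the CShuffle store vector length is odd.
constexpr bool needs_two_stage(std::size_t wei_type_size_bytes, int c_scalar_per_vector)
{
    return wei_type_size_bytes % 4 != 0 && c_scalar_per_vector % 2 != 0;
}

int main()
{
    std::printf("%d\n", needs_two_stage(2, 1)); // fp16 weights, scalar stores -> 1 (two-stage)
    std::printf("%d\n", needs_two_stage(2, 8)); // fp16 weights, wide vector   -> 0
    std::printf("%d\n", needs_two_stage(4, 1)); // fp32 weights                -> 0
}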


@@ -185,7 +185,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
using CDataType_ = CDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<

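The two aliases added above (CDEShuffleBlockTransferScalarPerVectors_ and CDataType_) exist so that the conv bwd-weight wrapper in the next file can inspect the GEMM's store vector length and intermediate C data type; they feed directly into IsTwoStageNeeded and TwoStageIntermediateType below.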

@@ -11,6 +11,8 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"

namespace ck {
namespace tensor_operation {
@@ -48,7 +50,48 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
static constexpr bool IsTwoStageNeeded =
sizeof(WeiDataType) % 4 != 0 &&
DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
static constexpr index_t ElementwiseBlockSize = 256;
static constexpr index_t ElemsPerBlock = 256;
static auto GetElementwiseCGridDesc(index_t merged_filter_dims)
{
const auto pad_size = merged_filter_dims % ElemsPerBlock == 0
? 0
: ElemsPerBlock - merged_filter_dims % ElemsPerBlock;
const auto desc = make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims));
return transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(I1),
make_right_pad_transform(merged_filter_dims, pad_size)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
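// Worked example of the padding above: with ElemsPerBlock = 256,
// merged_filter_dims = 1000 gives pad_size = 24 (padded length 1024),
// 4095 gives pad_size = 1 (4096), and exact multiples give pad_size = 0.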
using CElementwiseGridDesc = remove_cvref_t<decltype(GetElementwiseCGridDesc(I1))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>;
using GridwiseElementwiseCast = GridwiseElementwise<Tuple<CElementwiseGridDesc>,
Tuple<CElementwiseGridDesc>,
Tuple<const TwoStageIntermediateType*>,
Tuple<WeiDataType*>,
Block2TileMapElementwise,
WeiElementwiseOperation,
ElementwiseBlockSize,
I1,
ElemsPerBlock,
I1,
ElemsPerBlock / ElementwiseBlockSize,
Sequence<0, 1>,
Sequence<1>,
Sequence<1>,
I1,
I1>;
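// With ElemsPerBlock == ElementwiseBlockSize == 256, each thread of the cast
// kernel handles ElemsPerBlock / ElementwiseBlockSize = 1 element, and the
// trailing I1 parameters keep loads and stores scalar, which is what lets the
// second stage handle weight tensors whose merged extent is odd.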
struct Argument : public BaseArgument
{
@@ -58,11 +101,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>&, // input
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>&,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -74,42 +117,114 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
: filter_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
input_right_pads_{input_right_pads},
p_wei_grid_{p_wei_grid}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const index_t M = e_g_k_c_xs_lengths[I1];
const index_t N = e_g_k_c_xs_lengths[I2];
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1];
const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1];
const index_t StrideWei = e_g_k_c_xs_strides[I1];
const index_t StrideBatchOut = a_g_n_k_wos_strides[I0];
const index_t StrideBatchIn = b_g_n_c_wis_strides[I0];
const index_t StrideBatchWei = e_g_k_c_xs_strides[I0];
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if constexpr(IsTwoStageNeeded)
{
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
end(e_g_k_c_xs_lengths),
index_t{1},
std::multiplies<>{});
elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims);
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{1, merged_filter_dims};
// The filter tensor is packed iff the stride of the outermost (G) dimension
// equals the product of all remaining dimension lengths.
is_filter_data_packed =
e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]);
// The E pointer below is only a typed placeholder; it is swapped for the
// user-allocated workspace at launch, and IsSupportedArgument verifies that
// the workspace was actually set.
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
static_cast<TwoStageIntermediateType*>(nullptr),
M,
N,
K,
StrideOut,
StrideIn,
{},
StrideWei,
StrideBatchOut,
StrideBatchIn,
{},
StrideBatchWei,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
}
else
{
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
p_wei_grid,
M,
N,
K,
StrideOut,
StrideIn,
{},
StrideWei,
StrideBatchOut,
StrideBatchIn,
{},
StrideBatchWei,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(IsTwoStageNeeded)
{
return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize();
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
if constexpr(IsTwoStageNeeded)
{
return GetWorkspaceETensorSizeBytes();
}
else
{
return 0;
}
}
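// Example with illustrative sizes: G=2, K=3, C=5 and a 3x3 filter merge to
// 2*3*5*3*3 = 270 elements, padded to 512 by GetElementwiseCGridDesc, so a
// float intermediate needs 512 * sizeof(float) = 2048 workspace bytes.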
GemmArgument explicit_gemm_args;
@@ -117,16 +232,56 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
WeiDataType* p_wei_grid_;
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
using GemmArgument = typename DeviceGemmV3Op::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if constexpr(IsTwoStageNeeded)
{
// Modify to use workspace as output
GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
explicit_gemm_args_with_workspace.p_c_grid =
static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
float avg_time =
explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_);
const auto kernel = kernel_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc>,
ck::Tuple<CElementwiseGridDesc>,
ck::Tuple<const TwoStageIntermediateType*>,
ck::Tuple<WeiDataType*>,
Block2TileMapElementwise,
WeiElementwiseOperation>;
avg_time += launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(ElementwiseBlockSize),
0,
make_tuple(arg.elementwise_desc_),
make_tuple(arg.elementwise_desc_),
make_tuple(static_cast<const TwoStageIntermediateType*>(arg.p_workspace_)),
make_tuple(arg.p_wei_grid_),
arg.elementwise_block_2_ctile_map_,
element_wise::PassThrough{});
return avg_time;
}
else
{
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
}
}
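// Note: on the two-stage path the reported time is the sum of the GEMM stage
// and the cast kernel, both measured with the same StreamConfig.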
float Run(const BaseArgument* p_arg,
@@ -174,6 +329,26 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
return false;
}
}
if constexpr(IsTwoStageNeeded)
{
if(!arg.is_filter_data_packed)
{
return false;
}
// Checking the workspace here (rather than failing later) lets other
// instances from the factory be used even when no workspace was allocated.
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Warning: Workspace for "
"DeviceGroupedConvBwdWeight_Explicit_Xdl::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
}
// Gridwise GEMM size
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
}
@@ -277,6 +452,33 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->GetWorkspaceSizeBytes();
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeight_Explicit_Xdl::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->p_workspace_ = p_workspace;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedConvBwdWeight_Explicit_Xdl::Argument structure!");
}
};
} // namespace device
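The workspace protocol these overrides implement can be exercised end to end. Below is a self-contained toy mirror of that flow (illustrative names only, not the CK API): query the size, attach a buffer, check support, then run.

#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Toy mirror of the two-stage workspace protocol added above: the op reports
// a workspace size, refuses to run until a buffer is attached, then "runs"
// both stages.
struct ToyTwoStageOp
{
    std::size_t GetWorkSpaceSize() const { return 512 * sizeof(float); }
    void SetWorkSpacePointer(void* p) { p_workspace_ = p; }
    bool IsSupportedArgument() const { return p_workspace_ != nullptr; }
    void Run() const
    {
        if(!IsSupportedArgument())
            throw std::runtime_error("workspace not allocated, use SetWorkSpacePointer");
        std::puts("stage 1: explicit GEMM into float workspace");
        std::puts("stage 2: elementwise cast from workspace to weight tensor");
    }
    void* p_workspace_ = nullptr;
};

int main()
{
    ToyTwoStageOp op;
    std::vector<float> buf(op.GetWorkSpaceSize() / sizeof(float));
    op.SetWorkSpacePointer(buf.data());
    if(op.IsSupportedArgument())
        op.Run();
}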