Grouped Convolution Backward Weight Explicit GEMM (#2282)

* Grouped conv bwd weight explicit gemm

* 3d

* cmake fixes

* fix test

* fix
This commit is contained in:
Bartłomiej Kocot
2025-06-06 10:30:08 +02:00
committed by GitHub
parent 00247e3c29
commit 050cad09b5
33 changed files with 2539 additions and 115 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -242,6 +242,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
@@ -282,7 +283,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
const std::array<ck::index_t, NumDTensor> BatchStrideDs_;
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideC_;
};
@@ -291,6 +292,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
Argument() = default;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
@@ -413,19 +415,39 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
const auto clear_workspace = [&]() {
if(arg.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg.p_c_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
ave_time = launch_and_time_kernel_with_preprocess(stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{

View File

@@ -0,0 +1,284 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename DeviceGemmV3Op>
struct DeviceGroupedConvBwdWeight_Explicit_Xdl
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
static_assert(is_same_v<InElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<WeiElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
struct Argument : public BaseArgument
{
using GemmArgument = typename DeviceGemmV3Op::Argument;
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>&, // input
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>&,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>&,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
ck::index_t split_k)
: filter_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const index_t M = e_g_k_c_xs_lengths[I1];
const index_t N = e_g_k_c_xs_lengths[I2];
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
p_wei_grid,
M,
N,
K,
BatchSize * M,
BatchSize * N,
{},
N,
M,
N,
{},
M * N,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
}
GemmArgument explicit_gemm_args;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
typename DeviceGemmV3Op::Invoker explicit_gemm_op;
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else
{
return false;
}
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
// Gridwise GEMM size
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const ck::index_t split_k)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
b_g_n_c_wis_lengths, // input
b_g_n_c_wis_strides,
e_g_k_c_xs_lengths, // weight
e_g_k_c_xs_strides,
a_g_n_k_wos_lengths, // output
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const ck::index_t split_k) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid),
b_g_n_c_wis_lengths, // input
b_g_n_c_wis_strides,
e_g_k_c_xs_lengths, // weight
e_g_k_c_xs_strides,
a_g_n_k_wos_lengths, // output
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Explicit_Xdl"
<< "<" << DeviceGemmV3Op{}.GetTypeString() << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -391,53 +391,53 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using CElementwiseGridDesc_M_N =
remove_cvref_t<decltype(GetElementwiseCGridDesc<NDimSpatial>())>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

View File

@@ -328,53 +328,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =