[CK_TILE] Grouped Convolution Backward Data Direct Load (#6624)

## Proposed changes

Add Grouped Convolution Backward Data with Direct Load to the
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 device implementation.
This enables direct global memory loading (bypassing LDS) for the
backward data convolution path on gfx950, following the same pattern
already used in the backward weight and forward convolution paths.

Direct load for convolution backward data improves performance by
avoiding LDS round-trips for certain configurations on gfx950, which
supports a wider range of instructions. Correctness is currently checked
only at the point of use; a standalone unit test (UT) should be added in
the future.
jakpiase
2026-04-23 11:16:55 +02:00
committed by GitHub
parent eaeba5266b
commit 7f14d346f1
12 changed files with 1676 additions and 136 deletions

@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
static constexpr bool ALdsScalarLoadToVgpr =
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
static constexpr bool BLdsScalarLoadToVgpr =
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
// Note: Direct load uses the layout to create the proper block and mmtile descriptors
// TODO: Fix and verify RC layout for non-direct load (currently it returns wrong results)
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
std::conditional_t<DirectLoad,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor>,
std::conditional_t<DirectLoad,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor>,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
DirectLoad>;
DirectLoad,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
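
The compile-time selection in the hunks above can be sketched in isolation. This is a minimal illustration, not the CK code: `sketch::RowMajor`, `sketch::ColumnMajor`, `PipelineVersion`, and `GemmConfig` are hypothetical stand-ins, and it assumes the two `std::conditional_t` expressions choose the A and B layouts passed to the gridwise GEMM. It also shows that the `? true : false` in the two `LdsScalarLoadToVgpr` flags is redundant, since the condition is already a `bool`.

```cpp
#include <type_traits>

namespace sketch {

// Hypothetical stand-ins for the tensor_layout::gemm tags (not the CK definitions).
struct RowMajor {};
struct ColumnMajor {};

// Hypothetical stand-in for BlockGemmPipelineVersion.
enum class PipelineVersion { v1, v2, v3 };

// Assumed reading of the diff: the A/B layouts are swapped when DirectLoad is on,
// and the scalar LDS-to-VGPR load flags are set only for DirectLoad with pipeline v1.
template <bool DirectLoad, PipelineVersion Ver>
struct GemmConfig
{
    using ALayout = std::conditional_t<DirectLoad, ColumnMajor, RowMajor>;
    using BLayout = std::conditional_t<DirectLoad, RowMajor, ColumnMajor>;

    // Equivalent to (DirectLoad && Ver == v1 ? true : false), without the ternary.
    static constexpr bool ALdsScalarLoadToVgpr = DirectLoad && Ver == PipelineVersion::v1;
    static constexpr bool BLdsScalarLoadToVgpr = DirectLoad && Ver == PipelineVersion::v1;
};

using DirectV1  = GemmConfig<true, PipelineVersion::v1>;
using DirectV3  = GemmConfig<true, PipelineVersion::v3>;
using RegularV1 = GemmConfig<false, PipelineVersion::v1>;

// Direct load flips both layouts relative to the regular path.
static_assert(std::is_same<DirectV1::ALayout, ColumnMajor>::value, "A is column-major");
static_assert(std::is_same<DirectV1::BLayout, RowMajor>::value, "B is row-major");
static_assert(std::is_same<RegularV1::ALayout, RowMajor>::value, "A is row-major");
static_assert(std::is_same<RegularV1::BLayout, ColumnMajor>::value, "B is column-major");

} // namespace sketch
```

Because everything is resolved at compile time, a mismatch between the `DirectLoad` flag and the expected layout shows up as a failed `static_assert` rather than as wrong results at run time.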