mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
[CK_TILE] Grouped Convolution Backward Data Direct Load (#6624)
## Proposed changes Add Grouped Convolution Backward Data with Direct Load into DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 device implementation. This enables direct global memory loading (bypassing LDS) for the backward data convolution path on gfx950, following the same pattern used in both backward weight and forward convolution. Direct load convolution backward data improves performance by avoiding LDS round-trips for certain configurations on gfx950, which supports a wider range of instructions. Currently correctness is checked only at usage point, but should be extended to a standalone UT in the future.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
static constexpr bool ALdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
static constexpr bool BLdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
|
||||
// Note: Direct load use layout to create proper block and mmtile descriptor
|
||||
// TODO: Fix and verify RC layout for not direct load (currently it returns wrong results)
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor>,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor>,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
DirectLoad>;
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user