Merge branch 'vpietila/ckb-fwd-bwd-instances' of github.com:ROCm/composable_kernel into vpietila/ckb-fwd-bwd-instances

This commit is contained in:
Ville Pietilä
2025-11-04 07:52:27 -06:00

View File

@@ -48,15 +48,13 @@ enum class GroupConvLayout3D
NGCDHW_GKCZYX_NGKDHW,
};
struct GroupConvLayout
{
union
{
struct GroupConvLayout {
union {
GroupConvLayout1D _1d;
GroupConvLayout2D _2d;
GroupConvLayout3D _3d;
};
constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {}
constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {}
constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {}
@@ -83,34 +81,32 @@ enum class FwdGroupConvDeviceOperation
// Backward data convolution device operations.
enum class BwdDataGroupConvDeviceOperation
{
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1,
DeviceGroupedConvBwdDataMultipleD,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
};
// Backward weight convolution device operations.
enum class BwdWeightGroupConvDeviceOperation
{
DeviceGroupedConvBwdWeight,
DeviceGroupedConvBwdWeight_Dl,
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeightMultipleD,
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
DeviceGroupedConvBwdWeightMultipleD,
DeviceGroupedConvBwdWeight_Dl
};
// Structural type for device operation
struct GroupConvDeviceOp
{
union
{
struct GroupConvDeviceOp {
union {
FwdGroupConvDeviceOperation _fwd;
BwdDataGroupConvDeviceOperation _bwd_data;
BwdWeightGroupConvDeviceOperation _bwd_weight;
};
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}