diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 29ccd7289f..5f60d8787d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -446,7 +446,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define GridwiseGemmMultiABDTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -462,7 +462,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched -#define GridwiseGemmTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ @@ -480,8 +480,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template using GridwiseGemmBase = ck::conditional_t< isMultiA || isMultiB, - GridwiseGemmMultipleABD_xdl_cshuffle, - GridwiseGemmMultipleD_xdl_cshuffle>; + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index b291b20bcd..d33e807828 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -439,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // GridwiseGemm -#define GridwiseGemmMultiDTemplateParams \ +#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ @@ -454,7 +454,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType -#define GridwiseGemmCTransposeTemplateParameters \ +#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \ @@ -470,10 +470,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType template - using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>; template - using GridwiseGemmCTransposeBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 698af8846d..a9b0975050 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define GridwiseGemmMultiABDTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -502,7 +502,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType -#define GridwiseGemmTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ @@ -518,7 +518,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType, DoElementwiseBeforeCShuffle -#define GridwiseGemmCTransposeTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ @@ -536,14 +536,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Use appropriate gridwise gemm template - using GridwiseGemmMultipleABDBase = - GridwiseGemmMultipleABD_xdl_cshuffle; + using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; template - using GridwiseGemmMultipleDBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; template - using GridwiseGemmMultipleDCTransposeBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS using GridwiseGemm64 = std::conditional_t - using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 =