Fix template parameter macros (#3305)

Some of the device implementation templates have macros like GridwiseGemmMultiABDTemplateParameters that can cause build errors if multiple files are included together. This error comes up with our builder code.

To clean up the macros and make them safer, we follow these follow rules:
* Use more specific names to avoid duplication.
* Undefine the macro after it is used to avoid leaking out of the file scope.
* Use a prefix CK_ on the macro to avoid conflicting with other libraries.
* Use all caps with underscores for preprocessor macro names.
This commit is contained in:
John Shumway
2025-11-26 09:48:17 -08:00
committed by GitHub
parent 35a4b26af0
commit 10a782d846
4 changed files with 32 additions and 22 deletions

View File

@@ -446,7 +446,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
#define GridwiseGemmMultiABDTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -462,7 +462,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
#define GridwiseGemmTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -480,8 +480,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <index_t NXdlPerWave_>
using GridwiseGemmBase = ck::conditional_t<
isMultiA || isMultiB,
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
GridwiseGemmMultipleABD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS>,
GridwiseGemmMultipleD_xdl_cshuffle<CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS>>;
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

View File

@@ -439,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// GridwiseGemm
#define GridwiseGemmMultiDTemplateParams \
#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
@@ -454,7 +454,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
#define GridwiseGemmCTransposeTemplateParameters \
#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \
ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \
@@ -470,10 +470,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>;
template <index_t NXdlPerWave_>
using GridwiseGemmCTransposeBase =
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>;
#undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;

View File

@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
#define GridwiseGemmMultiABDTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
@@ -502,7 +502,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType
#define GridwiseGemmTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -518,7 +518,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
BComputeDataType, DoElementwiseBeforeCShuffle
#define GridwiseGemmCTransposeTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \
GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
@@ -536,14 +536,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
using GridwiseGemmMultipleABDBase =
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmMultiABDTemplateParameters>;
using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
template <index_t NXdlPerWave_>
using GridwiseGemmMultipleDBase =
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
template <index_t NXdlPerWave_>
using GridwiseGemmMultipleDCTransposeBase =
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>;
using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>;
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS
using GridwiseGemm64 =
std::conditional_t<isMultiA || isMultiB,

View File

@@ -405,7 +405,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
is_split_valid);
}
#define GridwiseGemmTemplateParameters \
#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
@@ -422,9 +422,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
AComputeDataType, DoElementwiseBeforeCShuffle
// Use appropriate gridwise gemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>;
#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =