diff --git a/experimental/builder/validation_rules/conv_fwd_device_ops_template_parameter_rules.md b/experimental/builder/validation_rules/conv_fwd_device_ops_template_parameter_rules.md new file mode 100644 index 0000000000..783da6f531 --- /dev/null +++ b/experimental/builder/validation_rules/conv_fwd_device_ops_template_parameter_rules.md @@ -0,0 +1,512 @@ +# Template Parameter Constraint Rules for Forward Convolution Device Operations + +This document lists all static_assert rules and runtime validation checks that constrain template parameter selection for the five forward convolution device operations in Composable Kernel. + +## Device Operations Analyzed + +1. **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** +2. **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** +3. **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** +4. **DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor** +5. **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** + +--- + +## 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + +### Gridwise Implementation +Uses: `GridwiseGemmMultiD_xdl_cshuffle_v3` + +### Compile-Time Static Asserts (Gridwise Level) + +#### Block and Wave Tiling Constraints +```cpp +static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); +``` +- **Rule**: MPerBlock must be divisible by (MPerXdl × MXdlPerWave) +- **Rule**: NPerBlock must be divisible by (NXdlPerWave × NPerXdl) + +#### Shuffle Constraints +```cpp +static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); +``` +- **Rule**: MXdlPerWave must be divisible by CShuffleMXdlPerWavePerShuffle +- **Rule**: NXdlPerWave must be divisible by CShuffleNXdlPerWavePerShuffle + +### Compile-Time Static Asserts (Blockwise Level) + +From `BlockwiseGemmXdlops`: +```cpp +static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && + NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + 
+static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + +static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +``` +- **Rule**: MPerBlock must be divisible by (MPerXDL × MRepeat) +- **Rule**: NPerBlock must be divisible by (NPerXDL × NRepeat) +- **Rule**: KPerThread must be divisible by KPack +- **Rule**: BlockSize must equal MWaves × NWaves × WaveSize + +### Runtime Validation Checks + +#### Vector Access for A (Input) Tensor +For layouts G_NW_C, G_NHW_C, G_NDHW_C, GNWC, GNHWC, GNDHWC, NWGC, NHWGC, NDHWGC, NGCW, NGCHW, NGCDHW: +```cpp +C % ABlockTransferSrcScalarPerVector == 0 +``` +- **Rule**: C (input channels) must be divisible by ABlockTransferSrcScalarPerVector when ABlockTransferSrcVectorDim == 2 + +#### Vector Access for B (Weight) Tensor +For layouts G_K_X_C, G_K_YX_C, G_K_ZYX_C, GKXC, GKYXC, GKZYXC, KXGC, KYXGC, KZYXGC, GKCX, GKCYX, GKCZYX: +```cpp +C % BBlockTransferSrcScalarPerVector == 0 +``` +- **Rule**: C (input channels) must be divisible by BBlockTransferSrcScalarPerVector when BBlockTransferSrcVectorDim == 2 + +#### Vector Access for E (Output) Tensor +For layouts G_NW_K, G_NHW_K, G_NDHW_K, GNWK, GNHWK, GNDHWK, NWGK, NHWGK, NDHWGK, NGKW, NGKHW, NGKDHW: +```cpp +K % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` +- **Rule**: K (output channels) must be divisible by CDEBlockTransferScalarPerVector_NPerBlock + +#### Special NGCHW/NGCDHW Layout Constraints +For NGCHW/NGCDHW layouts requiring transpose: +```cpp +(G * C) % CDEBlockTransferScalarPerVector_NPerBlock == 0 +(G * K) % CDEBlockTransferScalarPerVector_NPerBlock == 0 +input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0 +output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` +- **Rule**: G×C must be divisible by CDEBlockTransferScalarPerVector_NPerBlock +- **Rule**: G×K must be divisible by 
CDEBlockTransferScalarPerVector_NPerBlock +- **Rule**: Product of input spatial dimensions must be divisible by CDEBlockTransferScalarPerVector_NPerBlock +- **Rule**: Product of output spatial dimensions must be divisible by CDEBlockTransferScalarPerVector_NPerBlock + +#### Descriptor Size Constraints +```cpp +a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= 2GB +b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= 2GB +c_grid_desc.GetElementSpaceSize() * sizeof(CDataType) <= 2GB +``` +- **Rule**: Each tensor descriptor must represent less than 2GB of data + +#### Device-Specific Constraints +- On **gfx908**: AccDataType must be `float` or `int32_t` +- **DirectLoad** mode: Only supported on gfx950 + +--- + +## 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + +### Gridwise Implementation +Uses: +- `GridwiseGemmMultipleABD_xdl_cshuffle` (when isMultiA || isMultiB) +- `GridwiseGemmMultipleD_xdl_cshuffle` (otherwise) + +### Compile-Time Static Asserts (Gridwise Level) + +Same as V3 version: +```cpp +static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + +static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + +static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); +``` +- **Rule**: MPerBlock must be divisible by (MPerXdl × MXdlPerWave) +- **Rule**: NPerBlock must be divisible by (NXdlPerWave × NPerXdl) +- **Rule**: KPerBlock must be divisible by both AK1Value and BK1Value +- **Rule**: MXdlPerWave must be divisible by CShuffleMXdlPerWavePerShuffle +- **Rule**: NXdlPerWave must be divisible by CShuffleNXdlPerWavePerShuffle + +### Runtime Validation Checks + +#### Vector Access for A (Input) Tensor +For standard layouts (G_NW_C, G_NHW_C, etc.): +```cpp +C % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim 
== 2 +``` +- **Rule**: C must be divisible by ABlockTransferSrcScalarPerVector + +Alternative for grouped layouts with C==1 or NumGroupsToMerge==1: +```cpp +G % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 1 +``` +- **Rule**: G must be divisible by ABlockTransferSrcScalarPerVector when accessing per G dimension + +For NGCHW/NGCDHW layouts without transpose: +```cpp +input_spatial_acum % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 1 +``` +- **Rule**: Product of input spatial dimensions must be divisible by ABlockTransferSrcScalarPerVector + +#### Vector Access for B (Weight) Tensor +```cpp +C % BBlockTransferSrcScalarPerVector == 0 // When BBlockTransferSrcVectorDim == 2 +``` +- **Rule**: C must be divisible by BBlockTransferSrcScalarPerVector + +#### Vector Access for D Tensors +For each D tensor with layouts G_NW_K, G_NHW_K, etc.: +```cpp +K % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` +- **Rule**: K must be divisible by CDEBlockTransferScalarPerVector_NPerBlock +- **Rule**: D and E tensors must have identical shapes (all dimensions must match) + +#### Vector Access for E (Output) Tensor +For standard layouts: +```cpp +K % CDEBlockTransferScalarPerVector_NPerBlock == 0 // When CTranspose == false +``` +For transposed layouts: +```cpp +output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` + +#### Transpose Kernel Requirements +For NGCHW/NGCDHW layouts with transpose: +```cpp +(G * C) % CDEBlockTransferScalarPerVector_NPerBlock == 0 +(G * K) % CDEBlockTransferScalarPerVector_NPerBlock == 0 +input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0 +output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` +- Workspace pointer must be allocated + +#### NumGroupsToMerge Constraints +When NumGroupsToMerge > 1: +```cpp +C == 1 +G % NumGroupsToMerge == 0 +``` +- **Rule**: C must equal 1 +- **Rule**: G must be divisible by NumGroupsToMerge + +#### Tensor 
Size Constraints +```cpp +a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= 2GB +b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= 2GB +e_grid_desc.GetElementSpaceSize() * sizeof(EDataType) <= 2GB +``` + +--- + +## 3. DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + +### Gridwise Implementation +Uses: `GridwiseGemmMultipleD_Wmma` + +### Compile-Time Static Asserts + +```cpp +static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + +static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); +static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); + +static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize"); +``` +- **Rule**: MPerBlock must be divisible by (MPerWmma × MRepeat) +- **Rule**: NPerBlock must be divisible by (NRepeat × NPerWmma) +- **Rule**: KPack must be divisible by (A_K1 × A_KRow) where A_KRow = 2 +- **Rule**: KPack must be divisible by (B_K1 × B_KRow) where B_KRow = 2 +- **Rule**: BlockSize must equal MWaves × NWaves × WaveSize + - Where: MWaves = MPerBlock / (MRepeat × MPerWmma) + - Where: NWaves = NPerBlock / (NRepeat × NPerWmma) + +### Derived Constraints +```cpp +K % K1 == 0 // Asserted in MakeAGridDescriptor and MakeBGridDescriptor +KPack = math::integer_least_multiple(K1, WmmaK) // Where WmmaK = 16 +``` +- **Rule**: K must be divisible by K1 +- **Rule**: KPack is K1 rounded up to the next multiple of WmmaK (16); this equals lcm(K1, 16) when K1 is a power of two + +### Runtime Validation Checks + +#### Device Support +```cpp +ck::is_gfx11_supported() || ck::is_gfx12_supported() +``` +- **Rule**: Only supports gfx11 and gfx12 architectures +- **Rule**: On these devices, AccDataType must be `float` or `int32_t` + +#### Vector Access for A +For layouts G_NW_C, G_NHW_C, G_NDHW_C, GNWC, GNHWC, GNDHWC, NWGC, NHWGC, NDHWGC: +```cpp +C % ABlockTransferSrcScalarPerVector == 0 // When ABlockTransferSrcVectorDim == 2 +``` + +#### Vector Access for B +For layouts 
G_K_X_C, G_K_YX_C, G_K_ZYX_C, GKXC, GKYXC, GKZYXC, KXGC, KYXGC, KZYXGC: +```cpp +C % BBlockTransferSrcScalarPerVector == 0 // When BBlockTransferSrcVectorDim == 2 +``` + +#### Vector Access for D and E +For all D tensors and E tensor: +```cpp +K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0 +``` + +--- + +## 4. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + +### Gridwise Implementation +Uses: `GridwiseGemmMultipleD_xdl_cshuffle` + +### Compile-Time Static Asserts + +Same as other XDL-based operations: +```cpp +static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + +static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + +static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); +``` + +### Runtime Validation Checks + +#### Tensor Splitting Validation +This operation splits large tensors that exceed 2GB: +```cpp +is_split_valid_ && gemms_count_ == valid_gemms_count_ +``` +- **Rule**: The tensor splitting algorithm must successfully partition the problem into sub-problems < 2GB + +#### D and E Tensor Matching +```cpp +ds_g_n_k_wos_strides_[i] == e_g_n_k_wos_strides_ +ds_g_n_k_wos_lengths_[i] == e_g_n_k_wos_lengths_ +``` +- **Rule**: All D tensors must have identical strides and lengths to E tensor + +#### Vector Access Constraints +Same as standard XDL operations: +```cpp +C % ABlockTransferSrcScalarPerVector == 0 +C % BBlockTransferSrcScalarPerVector == 0 +K % CDEBlockTransferScalarPerVector_NPerBlock == 0 +``` + +--- + +## 5. 
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + +### Gridwise Implementation +Uses: `GridwiseGemmDlMultipleD_km_kn_mn` + +### Compile-Time Static Asserts + +From `BlockwiseGemmDl_v2r3`: +```cpp +static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); +static_assert(BM0 == 2 && BN0 == 2, "wrong"); +static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, + "wrong! blocksize and cluster size not consistent"); +``` +- **Rule**: BM (MPerBlock) must be divisible by BM1 +- **Rule**: BN (NPerBlock) must be divisible by BN1 +- **Rule**: BM0 must equal 2 +- **Rule**: BN0 must equal 2 +- **Rule**: BlockSize must equal the product of thread cluster dimensions + +### Runtime Validation Checks + +#### Device Support +```cpp +ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || +ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported() +``` +- **Rule**: Must be one of: gfx906, gfx103, gfx11, gfx12, or support XDL instructions + +#### Vector Transfer Constraints for A +```cpp +srcVectorLengths[I1] == 1 && srcVectorLengths[I2] == 1 +K1 % srcVectorLengths[I3] == 0 +K0PerBlock % srcVectorLengths[I0] == 0 +C % (srcVectorLengths[I0] * srcVectorLengths[I3]) == 0 +``` +- **Rule**: Vector lengths for M dimensions must be 1 +- **Rule**: K1 must be divisible by K1 vector length +- **Rule**: K0PerBlock must be divisible by K0 vector length +- **Rule**: C must be divisible by the product of K0 and K1 vector lengths + +#### Vector Transfer Constraints for B +Same structure as A: +```cpp +srcVectorLengths[I1] == 1 && srcVectorLengths[I2] == 1 +K1 % srcVectorLengths[I3] == 0 +K0PerBlock % srcVectorLengths[I0] == 0 +C % (srcVectorLengths[I0] * srcVectorLengths[I3]) == 0 +``` + +#### Vector Access for E (Output) +```cpp +K % CThreadTransferDstScalarPerVector == 0 +CThreadTransferSrcDstVectorDim == 5 +``` +- **Rule**: K must be divisible by CThreadTransferDstScalarPerVector +- **Rule**: Vector dimension must be 5 (the K dimension) + +#### Tile Size Constraints 
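The tile-size constraints in this subsection require the GEMM extents M, N and K0 to be exact multiples of the corresponding block tile sizes. A minimal host-side sketch of that check follows; `DlGemmExtents`, `DlBlockTile` and `dl_tile_sizes_divide` are hypothetical names introduced for illustration and are not part of Composable Kernel.

```cpp
#include <cstdint>

// Hypothetical plain-value view of the GEMM problem seen by the DL kernel.
struct DlGemmExtents {
    std::int64_t M;
    std::int64_t N;
    std::int64_t K0;
};

// Hypothetical plain-value view of the DL block tile template parameters.
struct DlBlockTile {
    int MPerBlock;
    int NPerBlock;
    int K0PerBlock;
};

// True when every GEMM extent is an exact multiple of its block tile size.
constexpr bool dl_tile_sizes_divide(const DlGemmExtents& g, const DlBlockTile& t)
{
    // Reject non-positive tile sizes so the modulo below is well defined.
    if (t.MPerBlock <= 0 || t.NPerBlock <= 0 || t.K0PerBlock <= 0)
        return false;

    return g.M % t.MPerBlock == 0 &&
           g.N % t.NPerBlock == 0 &&
           g.K0 % t.K0PerBlock == 0;
}
```

A caller could run such a check while filtering candidate instances, so that an unsupported tile configuration is skipped early instead of being rejected by the device operation at run time.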
+```cpp +M % MPerBlock == 0 +N % NPerBlock == 0 +K0 % K0PerBlock == 0 +``` +- **Rule**: M must be divisible by MPerBlock +- **Rule**: N must be divisible by NPerBlock +- **Rule**: K0 must be divisible by K0PerBlock + +--- + +## Common Rules Across All Operations + +### Specialization Requirements + +For **Filter1x1Stride1Pad0** specialization: +```cpp +FilterSpatialDim == 1 +ConvStride == 1 +LeftPad == 0 +RightPad == 0 +``` +- Must be true for all spatial dimensions + +For **Filter1x1Pad0** specialization: +```cpp +FilterSpatialDim == 1 +LeftPad == 0 +RightPad == 0 +``` +- Must be true for all spatial dimensions + +For **Filter3x3** specialization: +```cpp +C == 1 +FilterSpatialDim == 3 +``` +- Must be true for all spatial dimensions + +### Pipeline Stage Constraints + +For non-v1 pipeline versions: +```cpp +num_k_loop > PrefetchStages +``` +- **Rule**: Number of K-blocks must exceed the number of prefetch stages + +### TF32 Support Constraints +```cpp +is_same_v<AComputeDataType, BComputeDataType> // When using TF32 +is_tf32_supported() // Device must support TF32 +``` +- **Rule**: When using TF32, A and B compute data types must match +- **Rule**: Device must have TF32 support + +### XDL/WMMA Support Validation +```cpp +ck::is_xdl_wmma_supported() +``` +- **Rule**: The combination of data types and XDL/WMMA tile sizes must be supported by the device + +--- + +## Summary of Key Parameter Relationships + +### Block-Level Tiling +``` +MPerBlock = MPerXdl × MXdlPerWave × MWaves +NPerBlock = NPerXdl × NXdlPerWave × NWaves +BlockSize = MWaves × NWaves × WaveSize +``` + +For WMMA: +``` +MPerBlock = MPerWmma × MRepeat × MWaves +NPerBlock = NPerWmma × NRepeat × NWaves +``` + +### K-Dimension Decomposition +``` +K = AK0 × AK1 = BK0 × BK1 +KPerBlock = AK0PerBlock × AK1 = BK0PerBlock × BK1 +``` + +### Shuffle Constraints +``` +MXdlPerWave = i × CShuffleMXdlPerWavePerShuffle (i is a positive integer) +NXdlPerWave = j × CShuffleNXdlPerWavePerShuffle (j is a positive integer) +``` + +### Vector Access Hierarchy +1. 
**Data must be aligned** to vector access size +2. **Dimensions accessed with vector loads or stores** must be divisible by ScalarPerVector +3. **Different layouts** have different vectorizable dimensions: + - Row-major A: vectorize K dimension + - Column-major A: vectorize M dimension + - Row-major B: vectorize N dimension + - Column-major B: vectorize K dimension + +### LDS Padding +``` +ABlockLdsExtraM: Padding for A matrix in LDS to avoid bank conflicts +BBlockLdsExtraN: Padding for B matrix in LDS to avoid bank conflicts +``` +- Often set to 1 on gfx950 to reduce bank conflicts + +--- + +## Implementation Notes + +1. **Hierarchy**: Device ops compose gridwise ops, which compose blockwise ops, which compose threadwise ops +2. **Memory Flow**: Global → LDS → Register (VGPR) → Compute → Register → LDS → Global +3. **Direct Load**: Some implementations support direct global-to-LDS loads (skipping the intermediate register staging) on gfx950 +4. **Pipeline Versions**: Different pipeline versions (v1, v2, v3, v4, v5) have different prefetch and scheduling strategies +5. **Multi-AB Support**: Some operations support multiple A/B input tensors (tuples) +6. 
**Transpose Support**: Some layouts require intermediate transpose operations with workspace allocation + +--- + +## Validation Checklist for Template Parameter Selection + +When selecting template parameters, verify: + +- [ ] MPerBlock % (MPerXdl × MXdlPerWave) == 0 +- [ ] NPerBlock % (NPerXdl × NXdlPerWave) == 0 +- [ ] KPerBlock % AK1 == 0 and KPerBlock % BK1 == 0 +- [ ] BlockSize == computed_from_waves_and_wave_size +- [ ] All channel/spatial dimensions divisible by respective ScalarPerVector values +- [ ] Tensor descriptors < 2GB each +- [ ] Correct device architecture and data type support +- [ ] Specialization requirements met (filter size, stride, padding) +- [ ] Shuffle parameters properly divide wave parameters +- [ ] Pipeline stage requirements met for chosen version +- [ ] Workspace allocated if using transpose kernels diff --git a/experimental/builder/validation_rules/template_parameter_validation_guide.md b/experimental/builder/validation_rules/template_parameter_validation_guide.md new file mode 100644 index 0000000000..7787376b3c --- /dev/null +++ b/experimental/builder/validation_rules/template_parameter_validation_guide.md @@ -0,0 +1,825 @@ +# Template Parameter Validation Guide for Forward Convolution Device Operations + +This guide maps each template parameter to its constraint rules, enabling upstream validation before instantiation. 
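As a rough illustration of such upstream validation, the sketch below checks the XDL tiling relationships described in this guide (block-tile divisibility, K decomposition, shuffle granularity, and BlockSize consistency) on plain integers, before any device-op template is instantiated. `XdlTilingParams` and `satisfies_xdl_tiling_rules` are hypothetical names, not part of Composable Kernel, and the default wave size of 64 is an assumption that should be adjusted per target architecture.

```cpp
// Hypothetical plain-value mirror of the XDL tuning parameters; the field
// names follow the template parameter names used in this guide.
struct XdlTilingParams {
    int BlockSize;
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;
    int AK1;
    int BK1;
    int MPerXDL;
    int NPerXDL;
    int MXdlPerWave;
    int NXdlPerWave;
    int CShuffleMXdlPerWavePerShuffle;
    int CShuffleNXdlPerWavePerShuffle;
};

// Returns true when the tiling rules hold. wave_size is supplied by the
// caller; 64 is assumed here as a default.
constexpr bool satisfies_xdl_tiling_rules(const XdlTilingParams& p, int wave_size = 64)
{
    const int m_tile = p.MPerXDL * p.MXdlPerWave;
    const int n_tile = p.NPerXDL * p.NXdlPerWave;

    // Reject non-positive divisors before using them in a modulo.
    if (m_tile <= 0 || n_tile <= 0 || p.AK1 <= 0 || p.BK1 <= 0 ||
        p.CShuffleMXdlPerWavePerShuffle <= 0 || p.CShuffleNXdlPerWavePerShuffle <= 0)
        return false;

    // Block tiling divisibility in M and N.
    if (p.MPerBlock % m_tile != 0 || p.NPerBlock % n_tile != 0)
        return false;

    // K-dimension decomposition.
    if (p.KPerBlock % p.AK1 != 0 || p.KPerBlock % p.BK1 != 0)
        return false;

    // Shuffle granularity in M and N.
    if (p.MXdlPerWave % p.CShuffleMXdlPerWavePerShuffle != 0 ||
        p.NXdlPerWave % p.CShuffleNXdlPerWavePerShuffle != 0)
        return false;

    // BlockSize must match the derived wave counts.
    const int m_waves = p.MPerBlock / m_tile;
    const int n_waves = p.NPerBlock / n_tile;
    return p.BlockSize == m_waves * n_waves * wave_size;
}
```

Because the function is `constexpr`, the same check can run inside a `static_assert` at build time or as an ordinary call when candidate configurations are generated at run time. It covers only the tiling rules; data-type, device, and layout constraints would need separate checks.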
+ +## Quick Reference: Template Parameter Names + +### Common XDL-based Parameters +- **BlockSize**: Number of threads per block +- **MPerBlock, NPerBlock, KPerBlock**: Block tile sizes for GEMM dimensions +- **AK1, BK1**: K-dimension decomposition (K = K0 × K1) +- **MPerXDL, NPerXDL**: XDL/MFMA instruction tile size +- **MXdlPerWave, NXdlPerWave**: Number of XDL tiles per wave +- **ABlockTransferSrcScalarPerVector**: Vector width for loading A matrix +- **BBlockTransferSrcScalarPerVector**: Vector width for loading B matrix +- **ABlockTransferDstScalarPerVector_AK1**: Vector width for storing A to LDS +- **BBlockTransferDstScalarPerVector_BK1**: Vector width for storing B to LDS +- **CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle**: C-shuffle granularity +- **CDEBlockTransferScalarPerVector_NPerBlock**: Vector width for C/D/E transfers +- **ABlockLdsExtraM, BBlockLdsExtraN**: LDS padding to avoid bank conflicts + +### WMMA-specific Parameters +- **MPerWmma, NPerWmma**: WMMA instruction tile size (typically 16×16) +- **K1**: K-dimension granularity +- **MRepeat, NRepeat**: Number of WMMA tiles per wave +- **CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle**: C-shuffle granularity + +### DL-specific Parameters +- **K0PerBlock**: K0 dimension of block tile +- **K1**: K1 value (typically 4 or 8) +- **M1PerThread, N1PerThread**: Thread tile size +- **KPerThread**: K dimension per thread + +--- + +## 1. 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + +### Template Declaration +```cpp +template < + index_t NDimSpatial, + typename ALayout, typename BLayout, typename DsLayout, typename ELayout, + typename ADataType, typename BDataType, typename AccDataType, + typename CShuffleDataType, typename DsDataType, typename EDataType, + typename AElementwiseOperation, typename BElementwiseOperation, + typename CDEElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, + index_t AK1, index_t BK1, + index_t MPerXDL, index_t NPerXDL, + index_t MXdlPerWave, index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + index_t ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + index_t BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEBlockTransferScalarPerVector_NPerBlock, + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + typename AComputeDataType = ..., + typename BComputeDataType = AComputeDataType, + bool DirectLoad = false +> +``` + +### Compile-Time Constraints (Can Check Before Instantiation) + +#### Rule 1: Block Tiling Divisibility +```cpp +CONSTRAINT: MPerBlock % (MPerXDL * MXdlPerWave) == 0 
+PARAMETERS: MPerBlock, MPerXDL, MXdlPerWave +``` + +#### Rule 2: Block Tiling Divisibility (N dimension) +```cpp +CONSTRAINT: NPerBlock % (NXdlPerWave * NPerXDL) == 0 +PARAMETERS: NPerBlock, NPerXDL, NXdlPerWave +``` + +#### Rule 3: K-dimension Decomposition +```cpp +CONSTRAINT: KPerBlock % AK1 == 0 +CONSTRAINT: KPerBlock % BK1 == 0 +PARAMETERS: KPerBlock, AK1, BK1 +``` + +#### Rule 4: Shuffle Granularity (M dimension) +```cpp +CONSTRAINT: MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 +PARAMETERS: MXdlPerWave, CShuffleMXdlPerWavePerShuffle +``` + +#### Rule 5: Shuffle Granularity (N dimension) +```cpp +CONSTRAINT: NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0 +PARAMETERS: NXdlPerWave, CShuffleNXdlPerWavePerShuffle +``` + +#### Rule 6: Derived - Wave Count and BlockSize +```cpp +DERIVED: MWaves = MPerBlock / (MPerXDL * MXdlPerWave) +DERIVED: NWaves = NPerBlock / (NPerXDL * NXdlPerWave) +DERIVED: WaveSize = 64 (or 32 for some architectures) +CONSTRAINT: BlockSize == MWaves * NWaves * WaveSize +PARAMETERS: BlockSize, MPerBlock, NPerBlock, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave +``` + +#### Rule 7: Derived - KPerThread and KPack +```cpp +DERIVED: KPack = max(lcm(AK1, BK1), MFMA_k_per_blk) +DERIVED: KPerThread = KPerBlock / (KPack) +CONSTRAINT: KPerThread % KPack == 0 +PARAMETERS: KPerBlock, AK1, BK1, MPerXDL, NPerXDL (MFMA size affects KPack) +``` + +#### Rule 8: XDL Support +```cpp +CONSTRAINT: is_xdl_wmma_supported() +PARAMETERS: AComputeDataType, BComputeDataType, MPerXDL, NPerXDL +``` + +#### Rule 9: TF32 Constraints +```cpp +CONSTRAINT: IF (AComputeDataType == tf32 OR BComputeDataType == tf32) + THEN AComputeDataType == BComputeDataType +PARAMETERS: AComputeDataType, BComputeDataType +``` + +#### Rule 10: DirectLoad Limitation +```cpp +CONSTRAINT: IF DirectLoad == true THEN device == "gfx950" +CONSTRAINT: IF DirectLoad == true THEN + AElementwiseOperation == PassThrough AND + BElementwiseOperation == PassThrough +PARAMETERS: DirectLoad, 
AElementwiseOperation, BElementwiseOperation +``` + +### Upstream Validation Function Template + +```cpp +template <typename DeviceOp> +struct TemplateParameterValidator { + static constexpr bool IsValid() { + // Rule 1: MPerBlock divisibility + if constexpr(DeviceOp::MPerBlock % + (DeviceOp::MPerXDL * DeviceOp::MXdlPerWave) != 0) + return false; + + // Rule 2: NPerBlock divisibility + if constexpr(DeviceOp::NPerBlock % + (DeviceOp::NXdlPerWave * DeviceOp::NPerXDL) != 0) + return false; + + // Rule 3: KPerBlock divisibility + if constexpr(DeviceOp::KPerBlock % DeviceOp::AK1 != 0 || + DeviceOp::KPerBlock % DeviceOp::BK1 != 0) + return false; + + // Rule 4-5: Shuffle constraints + if constexpr(DeviceOp::MXdlPerWave % DeviceOp::CShuffleMXdlPerWavePerShuffle != 0) + return false; + if constexpr(DeviceOp::NXdlPerWave % DeviceOp::CShuffleNXdlPerWavePerShuffle != 0) + return false; + + // Rule 6: BlockSize validation + constexpr auto MWaves = DeviceOp::MPerBlock / + (DeviceOp::MPerXDL * DeviceOp::MXdlPerWave); + constexpr auto NWaves = DeviceOp::NPerBlock / + (DeviceOp::NPerXDL * DeviceOp::NXdlPerWave); + constexpr auto WaveSize = 64; // Adjust based on architecture + if constexpr(DeviceOp::BlockSize != MWaves * NWaves * WaveSize) + return false; + + // Additional checks... + return true; + } +}; +``` + +--- + +## 2. 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + +### Template Declaration +```cpp +template < + index_t NDimSpatial, + typename ALayout, typename BLayout, typename DsLayout, typename ELayout, + typename ADataType, typename BDataType, typename AccDataType, + typename CShuffleDataType, typename DsDataType, typename EDataType, + typename AElementwiseOperation, typename BElementwiseOperation, + typename CDEElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, + index_t AK1, index_t BK1, + index_t MPerXDL, index_t NPerXDL, + index_t MXdlPerWave, index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + index_t ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + index_t BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEBlockTransferScalarPerVector_NPerBlock, + typename AComputeDataType = ..., + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1 +> +``` + +### Compile-Time Constraints + +Same as V3 (Rules 1-9 above) plus: + +#### Rule 11: NumGroupsToMerge (compile-time checkable when C=1 is known) +```cpp +CONSTRAINT: IF NumGroupsToMerge > 1 THEN must use specific layouts +PARAMETERS: 
NumGroupsToMerge, ALayout, BLayout, ELayout +REQUIRED_LAYOUTS: NSpatialGC_GKSpatial_NSpatialGK OR + NGCSpatial_GKSpatial_NGKSpatial OR + NGCHW_NGKHW OR NGCDHW_NGKDHW +``` + +--- + +## 3. DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + +### Template Declaration +```cpp +template < + index_t NDimSpatial, + typename ALayout, typename BLayout, typename DsLayout, typename ELayout, + typename ADataType, typename BDataType, typename AccDataType, + typename CShuffleDataType, typename DsDataType, typename EDataType, + typename AElementwiseOperation, typename BElementwiseOperation, + typename CDEElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, + index_t K1, + index_t MPerWmma, index_t NPerWmma, + index_t MRepeat, index_t NRepeat, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + index_t CShuffleMRepeatPerShuffle, + index_t CShuffleNRepeatPerShuffle, + typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1 +> +``` + +### Compile-Time Constraints + +#### Rule 12: WMMA Block Tiling (M dimension) +```cpp +CONSTRAINT: MPerBlock % (MPerWmma * MRepeat) == 0 
+PARAMETERS: MPerBlock, MPerWmma, MRepeat +``` + +#### Rule 13: WMMA Block Tiling (N dimension) +```cpp +CONSTRAINT: NPerBlock % (NPerWmma * NRepeat) == 0 +PARAMETERS: NPerBlock, NPerWmma, NRepeat +``` + +#### Rule 14: WMMA Wave Count and BlockSize +```cpp +DERIVED: MWaves = MPerBlock / (MPerWmma * MRepeat) +DERIVED: NWaves = NPerBlock / (NPerWmma * NRepeat) +DERIVED: WaveSize = 32 (for WMMA architectures) +CONSTRAINT: BlockSize == MWaves * NWaves * WaveSize +PARAMETERS: BlockSize, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat +``` + +#### Rule 15: WMMA KPack Constraints +```cpp +DERIVED: WmmaK = (K1 == 16) ? 32 : 16 +DERIVED: KPack = math::integer_least_multiple(K1, WmmaK) +CONSTRAINT: KPack % (K1 * 2) == 0 // A_KRow = 2 +CONSTRAINT: KPack % (K1 * 2) == 0 // B_KRow = 2 +PARAMETERS: K1, KPerBlock +NOTE: KPerBlock should be chosen such that resulting KPack satisfies these +``` + +#### Rule 16: Device Architecture for WMMA +```cpp +CONSTRAINT: is_gfx11_supported() OR is_gfx12_supported() +CONSTRAINT: AccDataType == float OR AccDataType == int32_t +PARAMETERS: AccDataType +DEVICE: Must be gfx11 or gfx12 +``` + +--- + +## 4. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + +### Template Declaration +Same as the standard XDL version (similar to #2 but without NumGroupsToMerge) + +### Compile-Time Constraints +Rules 1-9 apply (same as DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle) + +### Special Consideration +This operation handles tensors > 2GB by splitting, so descriptor size constraints are managed internally by the splitting algorithm. + +--- + +## 5. 
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + +### Template Declaration +```cpp +template < + index_t NDimSpatial, + typename ADataType, typename BDataType, + typename DsDataType, typename EDataType, typename AccDataType, + typename ALayout, typename BLayout, typename DsLayout, typename ELayout, + typename AElementwiseOperation, typename BElementwiseOperation, + typename CDEElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, index_t NPerBlock, + index_t K0PerBlock, index_t K1, + index_t M1PerThread, index_t N1PerThread, index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector +> +``` + +### Compile-Time Constraints + +#### Rule 17: DL Thread Cluster Constraints +```cpp +DERIVED: BM1 = M1PerThread +DERIVED: BN1 = N1PerThread +DERIVED: BM = MPerBlock +DERIVED: BN = NPerBlock +CONSTRAINT: BM % BM1 == 0 +CONSTRAINT: BN % BN1 == 0 +PARAMETERS: MPerBlock, NPerBlock, M1PerThread, N1PerThread +``` + +#### Rule 18: DL Grid Decomposition 
+```cpp
+CONSTRAINT: BM0 == 2
+CONSTRAINT: BN0 == 2
+DERIVED: BM0, BN0 are derived from thread cluster configuration
+PARAMETERS: M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
+NOTE: Thread cluster must result in BM0=2, BN0=2
+```
+
+#### Rule 19: DL BlockSize Validation
+```cpp
+DERIVED: BM101, BM100, BN101, BN100 from thread cluster
+CONSTRAINT: BlockSize == BM101 * BM100 * BN101 * BN100
+PARAMETERS: BlockSize, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs
+```
+
+#### Rule 20: DL Vector Transfer Dimensions
+```cpp
+CONSTRAINT: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1[I1] == 1
+CONSTRAINT: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1[I2] == 1
+CONSTRAINT: BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1[I1] == 1
+CONSTRAINT: BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1[I2] == 1
+PARAMETERS: ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
+            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
+```
+
+#### Rule 21: DL Output Vector Dimension
+```cpp
+CONSTRAINT: CThreadTransferSrcDstVectorDim == 5
+PARAMETERS: CThreadTransferSrcDstVectorDim
+```
+
+---
+
+## Consolidated Validation Matrix
+
+### For XDL-based Operations (V3, Standard, Large_Tensor)
+
+| Parameter | Constraint | Depends On | Rule # |
+|-----------|------------|------------|--------|
+| MPerBlock | % (MPerXDL × MXdlPerWave) == 0 | MPerXDL, MXdlPerWave | 1 |
+| NPerBlock | % (NXdlPerWave × NPerXDL) == 0 | NPerXDL, NXdlPerWave | 2 |
+| KPerBlock | % AK1 == 0 AND % BK1 == 0 | AK1, BK1 | 3 |
+| MXdlPerWave | % CShuffleMXdlPerWavePerShuffle == 0 | CShuffleMXdlPerWavePerShuffle | 4 |
+| NXdlPerWave | % CShuffleNXdlPerWavePerShuffle == 0 | CShuffleNXdlPerWavePerShuffle | 5 |
+| BlockSize | == MWaves × NWaves × WaveSize | All M/N tiling params | 6 |
+| KPerThread | % KPack == 0 | KPerBlock, AK1, BK1 | 7 |
+| MPerXDL, NPerXDL | is_xdl_supported(...) | AComputeDataType, BComputeDataType | 8 |
+| AComputeDataType | == BComputeDataType (if TF32) | BComputeDataType | 9 |
+| DirectLoad | Requires gfx950 & PassThrough ops | Device, ElementwiseOps | 10 |
+
+### For WMMA-based Operations
+
+| Parameter | Constraint | Depends On | Rule # |
+|-----------|------------|------------|--------|
+| MPerBlock | % (MPerWmma × MRepeat) == 0 | MPerWmma, MRepeat | 12 |
+| NPerBlock | % (NPerWmma × NRepeat) == 0 | NPerWmma, NRepeat | 13 |
+| BlockSize | == MWaves × NWaves × WaveSize | All M/N tiling params | 14 |
+| KPack | % (K1 × 2) == 0 (for A and B) | K1 | 15 |
+| AccDataType | == float OR == int32_t | - | 16 |
+| Device | gfx11 or gfx12 only | - | 16 |
+
+### For DL-based Operations
+
+| Parameter | Constraint | Depends On | Rule # |
+|-----------|------------|------------|--------|
+| MPerBlock | % M1PerThread == 0 | M1PerThread | 17 |
+| NPerBlock | % N1PerThread == 0 | N1PerThread | 17 |
+| Thread Cluster | Must result in BM0=2, BN0=2 | M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs | 18 |
+| BlockSize | == BM101 × BM100 × BN101 × BN100 | Thread cluster config | 19 |
+| Vector Lengths | M/N dimensions == 1 | ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 | 20 |
+| CThreadTransferSrcDstVectorDim | == 5 | - | 21 |
+
+---
+
+## Upstream Validation Strategy
+
+### Phase 1: Compile-Time Template Validation
+
+Create a validation struct that can be checked at compile time:
+
+```cpp
+template <typename DeviceOp>
+struct DeviceOpTemplateValidator {
+    // Extract template parameters as constexpr values
+    static constexpr auto BlockSize = DeviceOp::BlockSize;
+    static constexpr auto MPerBlock = DeviceOp::MPerBlock;
+    static constexpr auto NPerBlock = DeviceOp::NPerBlock;
+    static constexpr auto KPerBlock = DeviceOp::KPerBlock;
+    // ... extract all other parameters
+
+    // Validate each rule
+    static constexpr bool ValidateRule1() {
+        if constexpr(is_xdl_based<DeviceOp>::value) {
+            return MPerBlock % (MPerXDL * MXdlPerWave) == 0;
+        }
+        return true; // N/A for non-XDL
+    }
+
+    static constexpr bool ValidateRule2() {
+        if constexpr(is_xdl_based<DeviceOp>::value) {
+            return NPerBlock % (NXdlPerWave * NPerXDL) == 0;
+        }
+        return true;
+    }
+
+    // ... all rules
+
+    static constexpr bool IsValid() {
+        return ValidateRule1() && ValidateRule2() /* && ... all rules */;
+    }
+
+    // Optional: Generate error messages
+    static constexpr const char* GetErrorMessage() {
+        if constexpr(!ValidateRule1())
+            return "MPerBlock not divisible by (MPerXDL * MXdlPerWave)";
+        if constexpr(!ValidateRule2())
+            return "NPerBlock not divisible by (NXdlPerWave * NPerXDL)";
+        // ... etc
+        return "Valid";
+    }
+};
+
+// Usage (MyDeviceOp is a placeholder for a concrete device operation type).
+// Before C++26, static_assert requires a string literal message, so
+// GetErrorMessage() cannot be passed as the second argument.
+static_assert(DeviceOpTemplateValidator<MyDeviceOp>::IsValid(),
+              "Invalid template parameters; call GetErrorMessage() for the failing rule");
+```
+
+### Phase 2: Runtime Problem-Size Validation
+
+After template instantiation, validate against actual problem dimensions:
+
+```cpp
+struct RuntimeValidator {
+    // Check vector access divisibility
+    // (Layout is a placeholder type describing the A/B/E tensor layouts)
+    static bool CheckVectorAccess(
+        index_t C, index_t K, index_t G,
+        index_t ABlockTransferSrcScalarPerVector,
+        index_t BBlockTransferSrcScalarPerVector,
+        index_t CDEBlockTransferScalarPerVector_NPerBlock,
+        const Layout& layouts)
+    {
+        // Check C divisibility for A and B
+        if (layouts.AUsesChannelVectorization())
+            if (C % ABlockTransferSrcScalarPerVector != 0)
+                return false;
+
+        if (layouts.BUsesChannelVectorization())
+            if (C % BBlockTransferSrcScalarPerVector != 0)
+                return false;
+
+        // Check K divisibility for E
+        if (K % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            return false;
+
+        return true;
+    }
+
+    // Check tile size divisibility
+    static bool CheckTileSizes(
+        index_t M, index_t N, index_t K,
+        index_t MPerBlock, index_t NPerBlock, index_t KPerBlock)
+    {
+        return (M % MPerBlock == 0) &&
+               (N % NPerBlock == 0) &&
+               (K % KPerBlock == 0);
+    }
+
+    // Check descriptor size limits
+    static bool CheckDescriptorSizes(
+        index_t M, index_t N, index_t K,
+        size_t ADataTypeSize, size_t BDataTypeSize, size_t EDataTypeSize)
+    {
+        // Build the limit from a 64-bit value; 1L is only 32 bits wide on some platforms
+        constexpr long_index_t TwoGB = long_index_t{1} << 31;
+        // Widen to 64 bits before multiplying so large extents do not overflow index_t
+        const auto bytes = [](index_t rows, index_t cols, size_t elem_size) {
+            return static_cast<long_index_t>(rows) * cols *
+                   static_cast<long_index_t>(elem_size);
+        };
+        return (bytes(M, K, ADataTypeSize) <= TwoGB) &&
+               (bytes(N, K, BDataTypeSize) <= TwoGB) &&
+               (bytes(M, N, EDataTypeSize) <= TwoGB);
+    }
+};
+```
+
+---
+
+## Complete Validation Code Template
+
+```cpp
+#pragma once
+#include <array>
+#include <type_traits>
+
+namespace ck {
+namespace validation {
+
+// Trait to detect operation type
+template <typename T> struct is_xdl_based : std::false_type {};
+template <typename T> struct is_wmma_based : std::false_type {};
+template <typename T> struct is_dl_based : std::false_type {};
+
+// Specialize for each operation type...
+
+template <typename DeviceOp>
+struct TemplateParameterValidator {
+    // Extract parameters
+    static constexpr auto BlockSize = DeviceOp::BlockSize;
+    static constexpr auto MPerBlock = DeviceOp::MPerBlock;
+    static constexpr auto NPerBlock = DeviceOp::NPerBlock;
+    static constexpr auto KPerBlock = DeviceOp::KPerBlock;
+
+    // XDL-specific validation
+    template <typename T = DeviceOp>
+    static constexpr bool ValidateXDL() {
+        if constexpr(is_xdl_based<T>::value) {
+            constexpr auto MPerXDL = T::MPerXDL;
+            constexpr auto NPerXDL = T::NPerXDL;
+            constexpr auto MXdlPerWave = T::MXdlPerWave;
+            constexpr auto NXdlPerWave = T::NXdlPerWave;
+            constexpr auto AK1 = T::AK1;
+            constexpr auto BK1 = T::BK1;
+            constexpr auto CShuffleMXdlPerWavePerShuffle = T::CShuffleMXdlPerWavePerShuffle;
+            constexpr auto CShuffleNXdlPerWavePerShuffle = T::CShuffleNXdlPerWavePerShuffle;
+
+            // Rule 1
+            if constexpr(MPerBlock % (MPerXDL * MXdlPerWave) != 0)
+                return false;
+
+            // Rule 2
+            if constexpr(NPerBlock % (NXdlPerWave * NPerXDL) != 0)
+                return false;
+
+            // Rule 3
+            if constexpr(KPerBlock % AK1 != 0 || KPerBlock % BK1 != 0)
+                return false;
+
+            // Rule 4
+            if constexpr(MXdlPerWave % CShuffleMXdlPerWavePerShuffle != 0)
+                return false;
+
+            // Rule 5
+            if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
+                return false;
+
+            // Rule 6
+            constexpr auto MWaves = MPerBlock / (MPerXDL * MXdlPerWave);
+            constexpr auto NWaves = NPerBlock / (NPerXDL * NXdlPerWave);
+            constexpr auto WaveSize = 64; // Adjust based on arch
+            if constexpr(BlockSize != MWaves * NWaves * WaveSize)
+                return false;
+
+            return true;
+        }
+        return true; // N/A
+    }
+
+    // WMMA-specific validation
+    template <typename T = DeviceOp>
+    static constexpr bool ValidateWMMA() {
+        if constexpr(is_wmma_based<T>::value) {
+            constexpr auto MPerWmma = T::MPerWmma;
+            constexpr auto NPerWmma = T::NPerWmma;
+            constexpr auto MRepeat = T::MRepeat;
+            constexpr auto NRepeat = T::NRepeat;
+            constexpr auto K1 = T::K1;
+
+            // Rule 12
+            if constexpr(MPerBlock % (MPerWmma * MRepeat) != 0)
+                return false;
+
+            // Rule 13
+            if constexpr(NPerBlock % (NPerWmma * NRepeat) != 0)
+                return false;
+
+            // Rule 14
+            constexpr auto MWaves = MPerBlock / (MPerWmma * MRepeat);
+            constexpr auto NWaves = NPerBlock / (NPerWmma * NRepeat);
+            constexpr auto WaveSize = 32;
+            if constexpr(BlockSize != MWaves * NWaves * WaveSize)
+                return false;
+
+            // Rule 15 - KPack constraints
+            constexpr auto WmmaK = (K1 == 16) ? 32 : 16;
+            // KPack computation is complex, may need runtime check
+
+            return true;
+        }
+        return true; // N/A
+    }
+
+    // DL-specific validation
+    template <typename T = DeviceOp>
+    static constexpr bool ValidateDL() {
+        if constexpr(is_dl_based<T>::value) {
+            constexpr auto M1PerThread = T::M1PerThread;
+            constexpr auto N1PerThread = T::N1PerThread;
+
+            // Rule 17
+            if constexpr(MPerBlock % M1PerThread != 0)
+                return false;
+            if constexpr(NPerBlock % N1PerThread != 0)
+                return false;
+
+            // Rules 18-19 require analyzing thread cluster which is complex
+            // May need partial compile-time, partial runtime validation
+
+            return true;
+        }
+        return true; // N/A
+    }
+
+    // Master validation
+    static constexpr bool IsValid() {
+        return ValidateXDL() &&
+               ValidateWMMA() &&
+               ValidateDL();
+    }
+};
+
+// Runtime validation for problem-dependent constraints
+template <typename DeviceOp>
+struct RuntimeParameterValidator {
+    static bool Validate(
+        index_t G, index_t N, index_t C, index_t K,
+        const std::array<index_t, DeviceOp::NDimSpatial>& spatial_dims_in,
+        const std::array<index_t, DeviceOp::NDimSpatial>& spatial_dims_out)
+    {
+        // Calculate spatial products
+        index_t input_spatial_acum = 1;
+        index_t output_spatial_acum = 1;
+        for(index_t i = 0; i < DeviceOp::NDimSpatial; ++i) {
+            input_spatial_acum *= spatial_dims_in[i];
+            output_spatial_acum *= spatial_dims_out[i];
+        }
+
+        // Check vector access based on layout
+        if constexpr(DeviceOp::UsesChannelVectorForA()) {
+            if(C % DeviceOp::ABlockTransferSrcScalarPerVector != 0)
+                return false;
+        }
+
+        if constexpr(DeviceOp::UsesChannelVectorForB()) {
+            if(C % DeviceOp::BBlockTransferSrcScalarPerVector != 0)
+                return false;
+        }
+
+        if(K % DeviceOp::CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            return false;
+
+        // Check tile divisibility
+        const auto MPadded = DeviceOp::CalculateMPadded(/* M calculated from problem */);
+        const auto NPadded = DeviceOp::CalculateNPadded(/* N calculated from problem */);
+        const auto KPadded = DeviceOp::CalculateKPadded(/* K calculated from problem */);
+
+        // Additional checks based on specialization...
+
+        return true;
+    }
+};
+
+} // namespace validation
+} // namespace ck
+```
+
+---
+
+## Usage Example
+
+```cpp
+// At compile-time (before instantiation)
+using MyDeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
+    2,   // NDimSpatial
+    /* ... layouts ... */
+    256, // BlockSize
+    128, // MPerBlock
+    128, // NPerBlock
+    16,  // KPerBlock
+    8,   // AK1
+    8,   // BK1
+    32,  // MPerXDL
+    32,  // NPerXDL
+    2,   // MXdlPerWave
+    2,   // NXdlPerWave
+    /* ... other params ... */
+>;
+
+// Compile-time validation
+static_assert(ck::validation::TemplateParameterValidator<MyDeviceOp>::IsValid(),
+              "Invalid template parameters for device operation!");
+
+// At runtime (when problem sizes are known)
+bool is_valid = ck::validation::RuntimeParameterValidator<MyDeviceOp>::Validate(
+    G, N, C, K, input_spatial_dims, output_spatial_dims);
+```
+
+---
+
+## Parameter Selection Guidelines
+
+### Step 1: Choose Architecture-Specific Base Values
+- **XDL**: MPerXDL=32, NPerXDL=32 (for fp16/bf16) or MPerXDL=16, NPerXDL=16 (for int8/fp8)
+- **WMMA**: MPerWmma=16, NPerWmma=16
+
+### Step 2: Determine Wave Configuration
+- Calculate desired waves: MWaves × NWaves
+- Ensure BlockSize = MWaves × NWaves × WaveSize
+
+### Step 3: Calculate Block Tile Sizes
+- **XDL**: MPerBlock = MPerXDL × MXdlPerWave × MWaves
+- **WMMA**: MPerBlock = MPerWmma × MRepeat × MWaves
+
+### Step 4: Choose K Decomposition
+- Select AK1, BK1 (typically 4, 8, or 16)
+- Ensure KPerBlock % AK1 == 0 and KPerBlock % BK1 == 0
+
+### Step 5: Choose Shuffle Parameters
+- CShuffleMXdlPerWavePerShuffle ≤ MXdlPerWave
+- Ensure MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0
+
+### Step 6: Select Vector Transfer Widths
+- Match to memory access pattern and data type size
+- Ensure alignment with channel/output dimensions
+
+### Step 7: Validate Against All Rules
+- Run compile-time validator
+- Test with representative problem sizes
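
The selection steps above can be tried out without any Composable Kernel headers. The following is a minimal sketch, not CK code: `XdlTuning`, `DeriveXdlTuning`, and `SatisfiesXdlRules` are hypothetical names introduced only for illustration, and the shuffle values of 1 are assumed rather than taken from the usage example. It derives BlockSize and the block tile from a wave layout (Steps 2 and 3) and then re-checks Rules 1 to 6 from the XDL validation matrix (Step 7).

```cpp
// Hypothetical container for XDL tuning parameters. Field names mirror the
// template parameters discussed above, but this is not a CK type.
struct XdlTuning {
    int BlockSize;
    int MPerBlock, NPerBlock, KPerBlock;
    int AK1, BK1;
    int MPerXDL, NPerXDL;
    int MXdlPerWave, NXdlPerWave;
    int CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle;
};

// Steps 2-3: derive BlockSize and the block tile sizes from the wave layout.
constexpr XdlTuning DeriveXdlTuning(int MPerXDL, int NPerXDL,
                                    int MXdlPerWave, int NXdlPerWave,
                                    int MWaves, int NWaves, int WaveSize,
                                    int KPerBlock, int AK1, int BK1,
                                    int CShuffleM, int CShuffleN)
{
    return XdlTuning{MWaves * NWaves * WaveSize,
                     MPerXDL * MXdlPerWave * MWaves,
                     NPerXDL * NXdlPerWave * NWaves,
                     KPerBlock, AK1, BK1, MPerXDL, NPerXDL,
                     MXdlPerWave, NXdlPerWave, CShuffleM, CShuffleN};
}

// Step 7: re-check Rules 1-6 of the XDL matrix on a parameter set.
constexpr bool SatisfiesXdlRules(const XdlTuning& t, int WaveSize)
{
    const int m_tile = t.MPerXDL * t.MXdlPerWave;
    const int n_tile = t.NPerXDL * t.NXdlPerWave;
    if(t.MPerBlock % m_tile != 0) return false;                             // Rule 1
    if(t.NPerBlock % n_tile != 0) return false;                             // Rule 2
    if(t.KPerBlock % t.AK1 != 0 || t.KPerBlock % t.BK1 != 0) return false;  // Rule 3
    if(t.MXdlPerWave % t.CShuffleMXdlPerWavePerShuffle != 0) return false;  // Rule 4
    if(t.NXdlPerWave % t.CShuffleNXdlPerWavePerShuffle != 0) return false;  // Rule 5
    const int MWaves = t.MPerBlock / m_tile;
    const int NWaves = t.NPerBlock / n_tile;
    return t.BlockSize == MWaves * NWaves * WaveSize;                       // Rule 6
}

// 2 x 2 waves of 64 lanes with 32x32 XDL tiles, as in the usage example;
// the shuffle parameters (1, 1) are assumed values.
constexpr XdlTuning kExample =
    DeriveXdlTuning(32, 32, 2, 2, 2, 2, 64, 16, 8, 8, 1, 1);

static_assert(kExample.BlockSize == 256, "2 * 2 waves * 64 lanes");
static_assert(kExample.MPerBlock == 128 && kExample.NPerBlock == 128, "32 * 2 * 2");
static_assert(SatisfiesXdlRules(kExample, 64), "derived parameters satisfy Rules 1-6");
```

Deriving MPerBlock, NPerBlock, and BlockSize from the wave layout, rather than choosing them independently, makes Rules 1, 2, and 6 hold by construction; only the K decomposition and shuffle parameters can still fail the check.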